薄いブログ

技術の雑多なことを書く場所

go-openapi/loadsを速くした話

github.com

GitHub - go-openapi/loads: openapi specification object model の速度を改善したときの話をしたいと思います.

go-openapi/loadsはgoでopenapiのschemaを読み込むためのライブラリです.

調査の結果json.Unmarshalが重く,このライブラリのAnalyzedという関数から呼び出されていることがわかっていました.

実際に見てみると以下のような実装でした.

// Analyzed creates a new analyzed spec document
func Analyzed(data json.RawMessage, version string) (*Document, error) {
    if version == "" {
        version = "2.0"
    }
    if version != "2.0" {
        return nil, fmt.Errorf("spec version %q is not supported", version)
    }

    raw := data
    trimmed := bytes.TrimSpace(data)
    if len(trimmed) > 0 {
        if trimmed[0] != '{' && trimmed[0] != '[' {
            yml, err := swag.BytesToYAMLDoc(trimmed)
            if err != nil {
                return nil, fmt.Errorf("analyzed: %v", err)
            }
            d, err := swag.YAMLToJSON(yml)
            if err != nil {
                return nil, fmt.Errorf("analyzed: %v", err)
            }
            raw = d
        }
    }

    swspec := new(spec.Swagger)
    if err := json.Unmarshal(raw, swspec); err != nil {
        return nil, err
    }

    origsqspec := new(spec.Swagger)
    if err := json.Unmarshal(raw, origsqspec); err != nil {
        return nil, err
    }

    d := &Document{
        Analyzer: analysis.New(swspec),
        schema:   spec.MustLoadSwagger20Schema(),
        spec:     swspec,
        raw:      raw,
        origSpec: origsqspec,
    }
    return d, nil
}

パッと見ると遅いjson.Unmarshalが2回同じ対象に実行されていることがわかります.

詳細な背景はわかりませんが, どうやら一つはプログラム中で変更してしまうもので元のデータを保持しておきたいためのものだとわかります.

つまり, 2回目のjson.UnmarshalはjsonをUnmarshalしたい目的ではなくswspecと同等のものを用意したい, deep copyがしたいという意図でした.

json.Unmarshalのコストが低ければこれでも問題ないと思うのですが, spec.Swaggerはjson.Unmarshal時にOpenAPIの仕様を満たすために独自のjson.Unmarshalerが実装されています.

どうなっているか確認してみました.

spec/swagger.go at 93213dab6b424cc2dd3fe3aab33e6c5660aa3343 · go-openapi/spec · GitHub

func (s *Swagger) UnmarshalJSON(data []byte) error {
    var sw Swagger
    if err := json.Unmarshal(data, &sw.SwaggerProps); err != nil {
        return err
    }
    if err := json.Unmarshal(data, &sw.VendorExtensible); err != nil {
        return err
    }
    *s = sw
    return nil
}

まず最初にswaggerが満たすべき最低条件のためにsw.SwaggerPropsをjson.Unmarshalし, X-から始まる拡張用のフィールドのために再度json.Unmarshalする実装になっています.

それ以外にもspec.Swaggerのjson.UnmarshalはX-から始まる拡張用フィールドをいたるところでサポートしているので普通のものと比べて重い処理であることがわかります.

一度jsonを解釈してgoの構造体に落としているのだから, deep copyのためだけに再度jsonを解釈するコストが無駄なように思えます.

確かに構造体をポインタなども含めてちゃんとdeep copyをするコードを書くのは面倒くさいのでやりたくない気持ちは理解できます.

Golang Benchmark: gob vs json · GitHub

をみてgobにencodeしてからdecodeするほうがjson.Unmarshalによるdeep copyより速いのではないかと思いました.

実際に実装して, Benchmarkのコードを書いてみると2倍くらい速度の改善が見られました.

しかし,どんな場合でもjson.Unmarshalよりgobのほうが優れているということはないはずなのでちゃんと計測を行いましょう.

そしてgobを実際に使用する場合, 事前にgob.Registerで登録していないと[]interface{}map[string]interface{}がうまく扱えないので注意してください.

また,encodeしてdecodeする形式は一時的にメモリを使用するという点とちゃんと実装されていなければ公開されていないフィールドは扱えないのでそこに注意する必要があります.

ちなみにPRの本文に書いてあるようなBenchmark間の差を見たい場合は

perf/cmd/benchstat at master · golang/perf · GitHub

のようなツールが非常に便利です. ぜひ使ってみてください.

最後に

このようなケースにおいて, 高速なdeep copyを実現する場合にもっと優れている方法を知っている方は教えていただけると幸いです.

もしくはPRを送ってもらってもいいと思います.

PythonでUnicodeDecodeError/UnicodeEncodeErrorが出たときの原因調査法

Python2を使っているマルチバイト圏の人間なら一度は遭遇したことがあるであろうUnicodeDecodeError.

スタックトレースに出る情報は

UnicodeDecodeError: 'ascii' codec can't decode byte 0xe3 in position 0: ordinal not in range(128)

こういう出力で正直どこで失敗したのかどの変数がどの関数で処理されたときに起こったものなのかが分かりづらい.

これまでは適当にあたりをつけて.decode("utf-8")だったりu"こんにちは"のような感じで試してお茶を濁すことが多かった.

無意識のうちにUnicodeDecodeErrorが持っている情報というのを無視していた.

UnicodeDecodeErrorがどんな情報を持っているのかを説明している文章をあまり見かけないので書いてみようと思った.

6. Built-in Exceptions — Python 2.7.15 documentation

5. Built-in Exceptions — Python 3.7.0 documentation

ここにすべてが書いてある. 公式ドキュメントをちゃんと調査する人間しかたどり着けないようになっている.

python UnicodeDecodeErrorと調べると詰まっている日本人が山程いるし, 場当たり的な解決策を提示した記事が上位にわんさか出てくる.

少しでも時間に余裕があるときは公式ドキュメントを見るようにしたほうが良い.

それはともかくとしてencoding, object, reason, start, end の情報が得られる.

僕自身はutf-8を扱うことが多いので以下のようなコードで調査する.

try:
    raise_unicode_decode_error_function("こんにちは")  # ここはUnicodeDecodeErrorを起こす関数
except UnicodeDecodeError as e:
    print(e.object.decode("utf-8"))

これで対象のデータが分かる.

後はそのデータがどの変数に入っているものなのかどこから来たものなのかの調査ができればあとは適切にencode/decodeするだけだ.

余談

のversion 1.0.0のリリースを作業をしていてCIを追加したり, テストを追加していた際にPython2でのみ

UnicodeDecodeErrorが発生しテストが落ちていたため, 原因調査の際に用いた手法をまとめようと思い書いた記事です.

完全に想定外だったのはevalの中の文字列リテラルがUnicodeDecodeErrorを起こしていたところです.

このへんをPython 2/3 compatibleに書くのってどうしたらいいんでしょうか. 教えてください.

最後にですがembexprをstarしてくれると嬉しいです.

Wiener's Attack を実装した

はじめに

このスライドをみてそういえばWiener’s Attack実装したこと無いなと思ったので勉強がてら実装してみました.

元論文はCryptanalysis of Short RSA Secret Exponents
理論についての説明は
http://elliptic-shiho.hatenablog.com/entry/2015/12/18/205804
が詳しいです.
ここでは実装について書いていこうと思います.

実装について

Wiener’s Attackは連分数展開と連分数から近似分数を求める2つのパートからなります.
連分数展開に関してはユークリッドの互除法の過程で得られるので素直に実装します.
ユークリッドの互除法の計算量は$O(\log_{10} n)$

def rational_to_contfrac(x: int, y: int) -> Iterator[int]:
    while y:
        a = x // y
        yield a
        x, y = y, x - a * y

連分数から近似分数を求めるパートがあるのですが,まず連分数から有理数を復元するところから説明します.
連分数を後ろから足して逆数求めることを繰り返していけば復元できます.

def contfrac_to_rational(contfrac: List[int]) -> Tuple[int, int]:
    from functools import reduce
    return reduce(
        lambda f,q: (q * f[0] + f[1], f[0]),
        reversed(contfrac),
        (1, 0),
    )

ちなみに上記のコードは使ってないのでrepositoryにはありません.
前から復元する場合は元論文の式6を用いるとできます.実際にはこちらを使います.

$$ \begin{alignat}{2} n_0 & = & q_0, \qquad d_0 & = & 1 \\ n_1 & = & q_0 q_1 + 1, \qquad d_1 & = & q_1 \\ n_i & = & q_i n_{i-1}+n_{i-2}, \qquad d_i & = & q_i d_{i-1}+d_{i-2} \\ \end{alignat} $$

式からもわかる様に前の2つの状態を持っていれば次の状態を作ることができます.
このままだと$i$が$0$と$1$の時条件分岐しないといけません.
$-1$と$-2$の時を条件を満たす様に定義することで実装がシンプルになります.

$$ \begin{alignat}{2} n_{-2} & = & 0, \qquad d_{-2} & = & 1 \\ n_{-1} & = & 1, \qquad d_{-1} & = & 0 \\ n_i & = & q_i n_{i-1}+n_{i-2}, \qquad d_{i} & = & q_i d_{i-1}+d_{i-2} \\ \end{alignat} $$

def contfrac_to_rational_iter(contfrac: Iterable[int]) -> Iterator[Tuple[int, int]]:
    n0, d0 = 0, 1
    n1, d1 = 1, 0
    for q in contfrac:
        n = q * n1 + n0
        d = q * d1 + d0
        yield n, d
        n0, d0 = n1, d1
        n1, d1 = n, d

計算量は$O(n)$になります.
用途を考えてgeneratorとして返しています.

近似分数を列挙するのですが元論文の3章のアルゴリズムを用います.
そんなに難しくないので読んでみることをおすすめします.
基本的には連分数を前の方だけ使って復元した値を近似分数とします.
愚直に実装すると以下の様になります.

def convergents_from_contfrac_naive(contfrac: List[int]) -> Iterator[Tuple[int, int]]:
    m = len(contfrac)
    for i in range(m):
        if i % 2 == 0:
            q = contfrac[:i+1]
            q[i] += 1
            yield contfrac_to_rational(q)
        else:
            q = contfrac[:i+1]
            yield contfrac_to_rational(q)

添字が偶数のときは最後の要素を1足した物を有理数に復元し,それを近似分数とします.
これは毎回contfrac_to_rationalを呼んでいて無駄です.
偶数のときの処理が面倒くさい様に見えますが,式を見てみると最後の要素に1を足すという処理は1つ前の結果を足すこと同様だと言うことがわかります.
つまり直前の結果だけ持っていれば良いです.
$i=0$のときは定義した$i=-1$のときの結果を用います.

def convergents_from_contfrac(contfrac: Iterable[int]) -> Iterator[Tuple[int, int]]:
    n_, d_ = 1, 0
    for i, (n, d) in enumerate(contfrac_to_rational_iter(contfrac)):
        if i % 2 == 0:
            yield n + n_, d + d_
        else:
            yield n, d
        n_, d_ = n, d

これまで紹介して来たコードを組み合わせてattack関数を定義します.

def attack(e: int, n: int) -> Optional[int]:
    f_ = rational_to_contfrac(e, n)
    for k, dg in convergents_from_contfrac(f_):
        edg = e * dg
        phi = edg // k

        x = n - phi + 1
        if x % 2 == 0 and is_perfect_square((x // 2) ** 2 - n):
            g = edg - phi * k
            return dg // g
    return None

近似分数を総当りでdを特定します.
正しい有理数に復元できたかどうかの判定は元論文のとおりに実装しました.
平方数の判定は$O(log_2n)$程度でできるので,全体の計算量は$O((log_2n)^{2})$になります.
比較的簡単にWiener’s Attackを実装できました.
証明やどういう場合に使えるかについては元論文を読んで見てください.

おまけ

平方数の判定は2分探索を用いることで出来ますが,今回はgmpの実装を参考にしました.
平方数の$\mod256$は$44$パターンしか無いのでまずそれで篩にかけます.
全体の$80\%$程度を瞬時に判定することが出来ます.
残りの$20\%$程度に対して$\mod9,5,7,13,17$で同様のことを行います.
これにより全体の$96\%$に対して判定を行えます. 残りの$4\%$に対してニュートン法を用いて平方数判定を行います.

def is_perfect_square(n: int) -> bool:
    sq_mod256 = (1,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0)
    if sq_mod256[n & 0xff] == 0:
        return False

    mt = (
        (9, (1,1,0,0,1,0,0,1,0)),
        (5, (1,1,0,0,1)),
        (7, (1,1,1,0,1,0,0)),
        (13, (1,1,0,1,1,0,0,0,0,1,1,0,1)),
        (17, (1,1,1,0,1,0,0,0,1,1,0,0,0,1,0,1,1))
    )
    a = n % (9 * 5 * 7 * 13 * 17)
    if any(t[a % m] == 0 for m, t in mt):
        return False

    return isqrt(n) ** 2 == n

整数のみを扱うニュートン法を用いました.
ニュートン法は初期値を決めないといけませんが平方根を求める場合は簡単にある程度良い初期値が与えられます.
適切に$n$を決定すると$S \approx \alpha \cdot 2^{2n}$は$0.5 \leq \alpha \lt 2$になります.
$n$は$S$の最上位bitが立ってる場所を2で割ると求めることが出来ます.
$\sqrt{S}\approx\sqrt{\alpha}2^n$になるので初期値を$2^n$にすると正しい答えに近い地点にいるので収束が早くなります.

def isqrt(n: int) -> int:
    if n == 0:
        return 0

    x = 2 ** ((n.bit_length() + 1) // 2)
    while True:
        y = (x + n // x) // 2
        if y >= x:
            return x
        x = y

終わりに

Pythonなりにわかりやすく高速に実装できたかなと思います.
(JavaのBigIntegerを使って実装している例があったのですが読みづらくて辛かった) 実装してみて楽しかったので他の手法も論文を読みながら実装できるくらい頭良くなりたい.

参考文献

Cryptanalysis of Short RSA Secret Exponents:
https://www.cits.ruhr-uni-bochum.de/imperia/md/content/may/krypto2ss08/shortsecretexponents.pdf
公開鍵暗号 - RSA - Wiener’s Attack - ₍₍ (ง ˘ω˘ )ว ⁾⁾ < 暗号楽しいです:
http://elliptic-shiho.hatenablog.com/entry/2015/12/18/205804
pablocelayes/rsa-wiener-attack:
https://github.com/pablocelayes/rsa-wiener-attack
wihoho/Wiener-s-Attack:
https://github.com/wihoho/Wiener-s-Attack
平方数かどうかを高速に判定する方法 - hnwの日記
http://d.hatena.ne.jp/hnw/20140503
Integer square root:
https://en.wikipedia.org/wiki/Integer_square_root
Methods of computing square roots:
https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Rough_estimation

二辺連結成分分解

ゆらふなさんの記事を見て実装してみようと思ったので記事を書きました。
この記事ではlowlinkを使った方法でやっています。

* 2017年9月30日追記 *
二重辺と表記していましたが二辺と表現するほうが適切と判断したので変更しました.
それに付随して定義などの情報を追加しました.

定義

  • 橋 (Bridge): その一つの辺を取り除くとグラフが非連結になるような辺
  • K辺連結グラフ: Kより小さい数の辺を取り除いても連結であるグラフ
  • 二辺連結成分 (2-edge connected component): 2より小さい数(=1)の辺を取り除いても連結である部分グラフ -> 橋を含まない部分グラフ

ref: k-edge connected graph

内容

二辺連結成分は橋を含まないような部分グラフなので橋を列挙することで二辺連結成分分解ができることがわかる。
橋の検出、列挙に関しては次のpdfがわかりやすい。グラフ探索アルゴリズムとその応用
基本的には深さ優先探索をするだけで求まる。まずグラフに対して適当な頂点を決めそこから深さ優先探索を行う。

$ord[v]:$ 深さ優先探索でその頂点を何番目に訪れたか
$low[v]:$ 後退辺を一度だけ使ってたどり着ける$ord$の最小値 $(lowlink)$

とすると深さ優先探索木上の辺$uv$が橋であるかの判定は
$$ord[u] < low[v]$$ で行える。$ord$は深さ優先探索木の根方向に行けば行くほど小さくなるので$u$の子である$v$は$ord[u]<ord[v]$を満たす。
$ord[u] >= low[v]$の場合は$u$より根方向に後退辺を用いて移動でき, $u$より根方向の頂点からは$uv$を使わずに$u$に到達できるので
$uv$が橋でないことがわかる。逆に$ord[u]<low[v]$の場合は$v$から$u$より根方向に到達できないので橋であることがわかる。
一度の深さ優先探索で橋を列挙することができるので計算量は$O(n + m)$になる。

橋を列挙できれば橋の端点から橋となる辺を使わないように深さ優先探索で二辺連結成分を列挙できる。
ただ橋が存在しない場合は適当な頂点から深さ優先探索すると良い。
全体の計算量は$O(n + m)$になる。

実装は以下のようになった。

二辺連結成分分解

verifyにはARC039Dを用いた。submission

ARC039D

ICPC国内予選2016参加記

6/24(金)にあったICPC国内予選2016にICT48(@orisano,@ringoh72,@DeEn_queue)として出場しました。
結果は4完32位と去年より順位は下がってしまいとても悲しかったです。

大会前日

前日に今年用のライブラリを作ってないことに気が付き、去年のライブラリをベースに作成。
いろいろやってたら5時。1限もあるし寝れない。このまま挑むしかない。

大会当日

1限が終了してそのまま学校で就寝。3時間ほど寝てから今回国内予選初参加の
@ringoh72と@DeEn_queueの提出練習とプリンタのテストをして待機してました。

問題

A:全探索
僕がスピードのみを重視して O(n2) のコードを書いてAC。

B:問題を知らず
僕がC,Dを見ているうちに@ringoh72がサンプル合わないと言っていた。
C,Dの解法を考えているうちにACしていた。

C:篩っぽいことをする、最大ケースが与えられていなければ辛そう
@DeEn_queueが解法を考えてくれていたがブランクが長かったからなのか実装に悩んでいたので
僕が代わりにコードを書いてAC。

D:区間DP
僕が担当していて舐めプしてたら結構いろんなミスが見つかって辛かった。
解法自体は制約と見た感じからすぐわかったのだけどうまく実装できず。
そこそこ時間はかかったもののAC。

E:誤読
最初みて幾何かなと思い後回しにする。ちゃんと読むべきだった。
ちゃんと読んでも問題の意味を変に解釈してしまった。 おそらく一番時間を費やすべき問題だったと思うし反省。

F:外側から連結成分取り出して木を作って同型性判定
3人で考えてやるべきことはわかったのだが、「同型性判定ライブラリないからできない><」とか
甘えたことを言ってしまい断念。少し考えれば思いつきそうだし取り組むべきだった。

G:誤読
3人ともワープ場は飛行機の線路上に限らず任意の地点に置くことができると勘違いしてしまって断念。
clarができなくてはじめて困ってしまった。

H:わからず
フローのオーラを感じたが経験と知識不足から断念。

まとめ

二人とチームを組んで競プロするのは初めてだったので役割とかその辺がよくわからなかった。
チームで練習すべきだったなと痛感しました。そもそもチームが締め切り当日まで決まってなかったのがやばい。
アジア地区行けるといいな(3年連続になるといいな)

Git Challengeに参加してきました

3/5に開催されたmixiさん主催のgit challengeというイベントに参加してきました。

参加したきっかけ

@imishinist(たいさん)から「参加しないのですか???????」と煽られたのもありますが、
gitは頻繁に使うものの全然使いこなせてない感があり、そろそろスキルアップしたと思っていた矢先だったというのが大きいです。

感想

twitterで知っていた@KAGE_MIKUくんとチームで競技に取り組みました。
@KAGE_MIKUが話しやすい人でわからないところは聞くことができたし、shell芸っぽいところでは手伝うこともできたかなと思います。
結果として4位になり、嬉しさと悔しさがありました。
採点にトラブルがありHumanCIになっていたのが少し残念でしたが、それでもとても楽しめました。

競技終了後に解説があるのがとても良く、こんなコマンドあったんだ、こうやって解決するのか、などの気付きがありました。

こんな面白いイベントを開催してくれたmixiさんに感謝しかないです!
是非後輩たちは参加してみてください!!!!!!!!!!!!!!!!

学校とか部活とかでgit challengeのような取り組みをしてみるとgitの布教に使える気がするし作ってみたい感がある。

AOJ 0568 Pasta

AOJ0568Erlangで解いてみたのでBlogを書きます。
問題の概要については省略します。

前回と同様に JOI 2011-2012 予選 問題・採点用入力 から正答することを確認しています。

-module(aoj0568).
-export([main/0]).

main() ->
  {N, K} = get_pair(),
  Fixed = lists:reverse(lists:sort([get_pair() || _ <- lists:seq(1, K)])),
  io:format("~p~n", [pasta(N + 1, Fixed, {0, -1, -2})]),
  init:stop().

to_i(S) ->
  {I, _} = string:to_integer(S),
  I.

get_pair() ->
  [First, Second] = [to_i(V) || V <- string:tokens(io:get_line(""), " ")],
  {First, Second}.

pasta(N, [{N, P}|_], {Q, _, _}) when P /= Q -> 0;
pasta(N, [{N, P}|T], S={P, _, _}) -> pasta(N, T, S);
pasta(_, _, {P, P, P}) -> 0;
pasta(1, _, _) -> 1;
pasta(N, L, S={P, Q, _}) ->
  case get({N, S}) of
    undefined ->
      Ret = lists:foldl(fun(X, Acc) ->
                            Acc + pasta(N - 1, L, {X, P, Q})
                        end, 0, lists:seq(1, 3)) rem 10000,
      put({N, S}, Ret),
      Ret;
    Memo -> Memo
  end.

コードの概要

入力を受け取って、逆順にソートします。
逆順にソートする理由は引数を減らせるというだけの理由です。
基本的には全探索でパターンマッチにより遷移できない状態排除しています。
本来であれば過去2日分の情報を持つだけで遷移先を決定できるのですが、
過去3日分の情報を持つことでパターンマッチを容易にしています。
遷移する際に過去の情報をスライドさせるだけで良いのでとてもシンプルになります。

終わりに

やっぱりErlangだとプロセス辞書を使ってメモ化再帰が楽に綺麗にかけるので好きです。