狠狠撸
Submit Search
PyTorch, PixyzによるGenerative Query Networkの実装
?
Download as PPTX, PDF
?
1 like
?
587 views
S
Shohei Taniguchi
Follow
12/13(木)の顿尝贬补肠办蝉にて発表した骋蚕狈(生成クエリネットワーク)実装スライド
Read less
Read more
1 of 36
Download now
Download to read offline
More Related Content
PyTorch, PixyzによるGenerative Query Networkの実装
1.
PyTorch, PixyzによるGenerative Query
Networkの実装 1 松尾研B4 谷口尚平
2.
自己紹介 ? 名前:谷口尚平 ? 所属:東京大学松尾研究室 ?
学年:B4 ? 研究領域 – 深層生成モデル(世界モデル) – 強化学習のための状態表現学習 – マルチモーダル学習(言語?概念の創発) ? 今回は卒論に関連して進めていたGQNの 再現実装を共有させていただくため、参加 させていただきました。 2
3.
アウトライン 1. Generative Query
Networkとは – 概要 – Variational Autoencoder – DRAW – GQNのアーキテクチャ的な工夫点 – (余談) メタ学習としてのGQN 2. 実装紹介 – Pixyzの紹介 – Representation Network – Convolutional LSTM – Generation Network – その他 3
4.
アウトライン 1. Generative Query
Networkとは – 概要 – Variational Autoencoder – DRAW – GQNのアーキテクチャ的な工夫点 – (余談) メタ学習としてのGQN 2. 実装紹介 – Pixyzの紹介 – Representation Network – Convolutional LSTM – Generation Network – その他 4
5.
概要 一言で言うと ある視点からの観測が与えられたとき に、別の視点からの観測を予測する深 層生成モデル どうやって? Conditional DRAWを用いています DRAWとは? VAEと自己回帰モデルの合わせ技のよ うなものです 5 [DL輪読会]GQNと関連研究,世界モデルとの関係について /DeepLearningJP2016/dlgqn-111725780
6.
Variational Autoencoder ①入力から潜在変数 ?
への推論 ② ? から入力の再構成(生成) をニューラルネットでモデル化した深層生成モデル 尤度 ? ? を最大化したいが、直接は測れないので、 以下の変分下限を最大化するように学習する E ?(?|?) log ? ? ? ? ? ? ? ? = E ?(?|?) log ? ? ? ? DKL[? ? ? ||? ? ] 6 ? ?? Generation ?(?|?) Inference ?(?|?)
7.
Variational Autoencoder 直感的な理解の仕方 1. 潜在変数
? を平均 0, 分散 1の正規分布からサンプリングしてニューラルネット ワークに入れたら画像を生成するようにしたい 2. そこで中間表現を確率的な変数 ? にしたAutoencoderを用意して、画像の再構 成を学習させることにする – 再構成誤差はdecoder (generator) の出力をベルヌーイ分布の平均値と考え、その尤度を 計算する – ただし、カラー画像の場合は分散を固定した正規分布として尤度を計算することが多い (GQNもそう) 3. それだけだと ? になんの制約もないので、encoder (inference)が出力する ? が 平均 0, 分散 1に近づくようにKL項を用意して制約をかける ? 理論的には変分ベイズが背景にあるので、詳しく知りたい人は以下の元論文とレ ビュー論文を読むことをオススメします – Auto-Encoding Variational Bayes https://arxiv.org/abs/1312.6114 – Tutorial on Variational Autoencoders https://arxiv.org/abs/1606.05908 7
8.
Conditional VAE ? 通常のVAEでは事前分布
?(?) を平均 0, 分散 1 で固定していたが、それをラベル ? で 条件付けた ?(?|?) とし、ニューラルネットでモデル化することで、ラベルで条件付けた 生成を可能にしたもの e.g. 1の数字が書かれたMNIST画像を生成するなど ? この場合、最大化したいのは条件付き尤度 ? ?|? になるので、変分下限も以下のよう に変わる E ?(?|?,?) log ? ? ?, ? ? ?|? ? ? ?, ? = E ?(?|?,?) log ? ? ?, ? ? DKL[? ? ?, ? ||? ?|? ] 8 ?, ? ??, ? Generation ?(?|?, ?) Inference ?(?|?, ?) Prior ?(?|?) ? KL
9.
Conditional VAE 基本的にはGQNもこれがベースです 条件付けるラベル (?)
が与えられた観測と生成する画像の視点になっただけ – 与えられた観測 (context): M個の視点 ?? 1..? と観測 ?? 1..? のペア – 生成する画像の視点 (query): ?? ? これらで条件付けて、queryに対応する観測画像 ?? ? を生成する つまり、GQNは ?(? ?|? ?..?, ? ?..?, ? ?) をモデル化している ? 確率モデル的な枠組みはこれだけですが、GQNではアーキテクチャの面で 様々な工夫を行なっている – DRAWもその1つ 9
10.
DRAW ? Gregorらが提案したVAEと自己回帰モデルを組み合わせたモデル ? 確率モデル的にはVAEと同じ ?
VAEにおける潜在変数 ? への推論をRNNを用いて複数回に分けて自己回 帰的に行うことで、モデルの表現力を高めている ? ? ? = ?=1 ? ??(??|?, ?<?) ? 厳密にはDRAWを提案した論文では他にもAttentionを用いるなど様々な工 夫をしているが、一般的には VAEにおいて ? への推論を自己回帰的に行 う枠組みを総称してDRAWと呼ばれている(少なくともGQN系の論文ではそ ういう扱い) 10
11.
DRAW 変分下限 log ? ?
≥ E ? ? ? log ? ? ? ? DKL[? ? ? | ? ? = E ? ? ? log ?(?|?) ? DKL[ ?=1 ? ?? ?? ?, ?<? || ?=1 ? ?? ?? ?<? ] = E ? ? ? log ? ? ? ? ?=1 ? E ? ?<? ? [DKL ?? ?? ?, ?<? ?? ?? ?<? ? E ? ? ? log ? ? ? ? ?=1 ? DKL[??(??|?, ?<?)||??(??|?<?)] ? 分布の積のKLは各分布のKLの和でモンテカルロ近似できる 11
12.
DRAW アルゴリズム for ? =
1 to ? Prior Distribution ?? ?? ?<? = ? ? ?? ? ? Encoder RNN ? ? = ??? ??? ?, ??, ? ?, ? ? Posterior Sample ?? ~ ?? ?? ?, ?<? = ? ? ?? ? ? Decoder RNN ? ? = ??? ??? ??, ? ? KL Divergence DKL[?? ?? ?, ?<? ||?? ??|?<? ] Canvas ?? = ?? + ? ? ? ? Likelihood ?(?| ? ?) ? RNNを通して自己回帰的に潜在変数 ? を推論する ? 画像はそれぞれの ? から生成したものを重ね書きしていく (絵を描くのに似ているのでDRAWという名前が付いている) 12
13.
GQN contextとqueryで条件付けたConditional DRAW ? これまでの内容を踏まえると、GQNの変分下限は以下のようになる E
? ? ? ?, ?1..?, ?1..?, ? ? log ? ? ? ?, ?1..? , ?1..? , ? ? ? ?=1 ? DKL[?? ?? ? ? , ?1..? , ?1..? , ? ? ||??(??|?1..? , ?1..? , ? ? )] ? さらにGQNではcontextを圧縮するrepresentation networkを用意して、? = ?=1 ? ?(? ?, ? ?) としているため、以下のように簡潔化できる E ? ? ? ?, ?, ? ? log ? ? ? ?, ?, ? ? ? ?=1 ? DKL[?? ?? ? ? , ?, ? ? ||??(??|?, ? ? )] これがGQNの目的関数です 13
14.
GQN GQNのアルゴリズム for ? =
1 to ? Prior Distribution ?? ?? ?<? = ? ? ?? ? ? Encoder RNN ? ? = ??? ??? ? ?, ? ?, ?, ??, ? ?, ? ? Posterior Sample ?? ~ ?? ?? ? ?, ? ?, ?, ?<? = ? ? ?? ? ? Decoder RNN ? ? = ??? ??? ? ?, ?, ??, ? ? KL Divergence DKL[?? ?? ?, ?<? ||?? ??|?<? ] Canvas ?? = ?? + Δ ? ? Likelihood ?(? ? |? ? (? ?)) ? 基本的にはDRAWに条件付ける変数 (? ?, ?) が加わっただけ ? 最後に尤度をとる前にもう一度NNをかませているのが少し違うが、それ以外は 普通のConditional DRAW 14
15.
アーキテクチャ的な工夫点 ? Representation networkによって、context情報を圧縮して環境に関する事 前知識を決定論的な変数
? として学習させている – Neural Processes (Garnelo et al., 2018) でも提案されている手法 ?? = ?(? ?, ? ?) ? = ?=1 ? ?? – それぞれのcontext情報を合計している点が大きな特徴 – これはそれぞれのcontextが順序不変 (permutation invariant) であることを利用して いる (e.g. 得られる観測は見る視点の順番が変わっても変わらない) – もしこれをRNNにしてしまうと、後に得られた観測ほどcontextに大きな影響を与える ことになってしまう – ただし、合計することが最善なのかは議論の余地がある 15
16.
(余談) メタ学習としてのGQN GQNの問題設定は実はメタ学習と全く同じ – 任意のデータセット(ここではシーン)が与えられた時に入力(視点)から出力(観測) へ正しく写像できるように、すべてのデータセットで共有されるパラメータをメタ知識と して学習させる –
GQNではrepresentation networkがメタ知識を学習するパラメータを保持している ? NNによるメタ学習の主なアプローチは以下の2通り ① メタ知識をパラメータの初期値として学習させる (MAML) ② メタ知識を学習するためのネットワークを別に用意する – GQN (Neural Processes) は②に当たる ? permutation invariantな問題設定では、GQNのようなメタ学習的なアプロー チが有効であることを示したのも大きな貢献(世の中には同じようなアプ ローチで解ける問題がたくさん眠っているはず) 16
17.
アウトライン 1. Generative Query
Networkとは – 概要 – Variational Autoencoder – DRAW – GQNのアーキテクチャ的な工夫点 – (余談) メタ学習としてのGQN 2. 実装紹介 – Pixyzの紹介 – Representation Network – Convolutional LSTM – Generation Network – その他 17
18.
過去のGQN実装 TensorFlow https://github.com/ogroth/tf-gqn – 一番早く上がった実装 – 死ぬほどわかりづらい –
どうすればこんなにわかりづらく書けるのかわからない – ちゃんと動くっぽい Chainer https://github.com/musyoku/chainer-gqn – musyokuさんの実装 – わかりやすい – データセットを自作するコードも公開(さすがの実装力) – 結果のクオリティが高い PyTorch https://github.com/wohlert/generative-query-network-pytorch – めちゃくちゃわかりやすい – 学習が全くうまくいっていない → PyTorchだけまともに動くものがない 18
19.
Pixyz 弊研?鈴木雅大作のPyTorchベースの深層生成モデル用ライブラリ ? ネットワークを確率モデルで隠蔽する構造になっているため、尤度やKL Divergenceなど、確率分布間の操作をネットワーク構造を意識することなく 実装できるため、コードの可読性が上がる。 ? TensorFlow
Probability (Edward) などはネットワークと確率モデルを並列に 扱っているため、深層生成モデル用としてはやや使いづらい 19
20.
実装 ? 今回はPyTorch, Pixyzの両方で実装しました。 ?
基本的にはほぼ同じですが、両者を比較しながらPixyzの使い方も紹介して いきたいと思います。 注:現在のPixyzはLoss APIが自己回帰モデルに対応していないため、現状では Pixyzを使うメリットがあまりありません。近々、自己回帰に対応したバー ジョンがリリースされるので、そのタイミングでPixyz版のアップデートを 行う予定です。 ? 論文内では一部のハイパーパラメータに関する記述がなかったため、著者 のEslamiさんに確認をとりました。 (おそらく元のPyTorch実装が動かないのはここのハイパラが間違っている から) 20
21.
論文に記述のないハイパーパラメータ ? 基本的に変数の次元数に関する記述が全くないため著者に確認 – 潜在変数
? : batch_size x 3 x 16 x 16 – RNNの隠れ変数 ?, ? : batch_size x 128 x 16 x 16 – DRAWのcanvasとなる ? : batch_size x 128 x 64 x 64 – 潜在変数 z の次元数は学習に大きな影響を与えるようなので注意 (チャンネル数を大きくしすぎると全く学習が進みません) ? RNNへの入力のサイズを合わせるためのdown-sample, up-sampleのネット ワーク – すべてbias項なしの畳み込み or 逆畳み込みを1層用意し、サイズがRNNの隠れ変数 と同じになるようにする(channel数は変えない) e.g. 画像は、kernel_size: 4x4, stride: 4x4, padding: 0 の畳み込みをしてから RNNに入力する 21
22.
dataset/convert2torch.py ? DeepMindが公開しているデータセット https://github.com/deepmind/gqn-datasets ? 元データセットはTensorFlow用のものなので、PyTorch用に変換する ?
Shepard-Metzlerは約1日で変換できました ? To do: マルチプロセス化 22
23.
gqn_dataset.py データローダ周りのスクリプト sample_batch – 1つのデータには1つのシーンに相当する画像と視点のペア群が入っているため、学 習時にはそこからcontextに使うものとqueryに使うものをランダムにサンプリングして 使用する 23
24.
representation.py Representation networkの実装 24 論文で提案されている3種類を実装 Poolが一番いいらしい
25.
conv_lstm.py Coreで使うConvolutional LSTMの実装 25
26.
Core EncoderRNN, DecoderRNNに当たる部分 26 入力のサイズを合わせてConvLSTMに入れる
27.
Distributionクラス (Pixyz) 27 ネットワークを確率分布で隠蔽するクラス ここでは正規分布なので、平均 (loc)
と標準 偏差 (scale) を辞書形式で返す 初期化の際に、条件づける変数 (cond_var) と 変数 (var) をリストで渡す。 forwardの入力とcond_varを揃えることに注意
28.
model.py 28 PyTorch Pixyz *_coreが自己回帰の部分を担うConvolutional LSTM Pixyzではeta_*
の代わりに笔谤颈辞谤などの辫颈虫测锄.诲颈蝉迟谤颈产耻迟颈辞苍蝉クラスのインスタンスを立てる
29.
model.py 29 Pixyzではネットワークを 確率モデルで隠蔽している ため、q.sampleなどとする だけで分布からのサンプリ ングが可能で、コードが読 みやすくなる! PyTorch Pixyz
30.
train.py 30 負の変分下限をロス関数として学習する。 学習率と生成画像の分散のアニーリングはここで行 う。 TensorBoardXを用いてログを保存
31.
train.py オプション ? デフォルトでは、すべて論文通りの実装が動くようになっていますが、GPU のメモリが多くない場合のために、train.pyに以下のオプションがあります 1. --layers 自己回帰のループ数
(default: 12) 8くらいでも十分な結果が得られます 2. --shared_core 自己回帰のループ間で同じRNNを使うかどうか (default: False) これがパラメータをめちゃくちゃ増やしていますが、Trueでもほぼ同じ結果が得られます (論文内でもFalseだと少しだけ結果が良くなると書いてある) 31
32.
画像生成モデルの実装上の注意点 1. 最後の画像の出力を [0,
1] に抑える – 初めて画像生成系の実装をするときにやりがちなミス – 学習はできてるっぽいのに、生成画像に変なノイズが乗っているときは大抵これをし忘れて いる – モデルの出力は最後がsigmoid等になっていない限り [0, 1] になるとは限らないので、<0 or >1 の部分を0, 1に書き換える必要がある e.g. -0.1は実際に画像として表示させるときは0.9扱いになってしまうので、0に置き 換える – PyTorchならtorch.clampでできる 2. 生成画像の分布を分散固定の正規分布にしている場合は分布の平均を出力 するようにする – カラー画像の場合はだいたい該当する – 分散は学習の対象ではないため、分布からサンプリングするとノイズが乗ってしまう – Pixyzであれば、sampleではなく、sample_meanを使う 32
33.
結果 (Shepard-Metzler) 33 GroundTruth Prediction 2週間ほど回し続けた結果(71.5万ステップ) ほぼ見分けがつかないレベルで生成できるようになった 論文では200万ステップ回している(こちらのリソースでは1ヶ月くらいかかる…)
34.
結果 (Mazes) 34 GroundTruth Prediction 学習時間約2日
35.
得られた知見 ? 論文通りの実験を行うにはマシンパワーが 相当必要 – TiTAN
X (12GB)4枚でギリギリ乗るくらい – 著者はK80 (24GB)4枚で実験したらしい ? 潜在変数 ? の次元数を大きくしすぎると、 学習初期にKLを過剰に近づけるように学習 が進んでしまい、再構成が全く学習されな い – 表現力を高めるためにむやみに ? の次元を増 やすと逆効果 ? 論文通りまでパラメータを増やさなくても学 習は十分できる – 自己回帰のループを減らすなど 35 失敗するケース zのチャンネル数64 自己回帰8回で1日回した結果
36.
Link PyTorch https://github.com/iShohei220/torch-gqn Pixyz https://github.com/masa-su/pixyzoo/tree/master/GQN 36
Download