狠狠撸

狠狠撸Share a Scribd company logo
PyTorch, PixyzによるGenerative Query Networkの実装
1
松尾研B4 谷口尚平
自己紹介
? 名前:谷口尚平
? 所属:東京大学松尾研究室
? 学年:B4
? 研究領域
– 深層生成モデル(世界モデル)
– 強化学習のための状態表現学習
– マルチモーダル学習(言語?概念の創発)
? 今回は卒論に関連して進めていたGQNの
再現実装を共有させていただくため、参加
させていただきました。
2
アウトライン
1. Generative Query Networkとは
– 概要
– Variational Autoencoder
– DRAW
– GQNのアーキテクチャ的な工夫点
– (余談) メタ学習としてのGQN
2. 実装紹介
– Pixyzの紹介
– Representation Network
– Convolutional LSTM
– Generation Network
– その他
3
アウトライン
1. Generative Query Networkとは
– 概要
– Variational Autoencoder
– DRAW
– GQNのアーキテクチャ的な工夫点
– (余談) メタ学習としてのGQN
2. 実装紹介
– Pixyzの紹介
– Representation Network
– Convolutional LSTM
– Generation Network
– その他
4
概要
一言で言うと
ある視点からの観測が与えられたとき
に、別の視点からの観測を予測する深
層生成モデル
どうやって?
Conditional DRAWを用いています
DRAWとは?
VAEと自己回帰モデルの合わせ技のよ
うなものです
5
[DL輪読会]GQNと関連研究,世界モデルとの関係について
/DeepLearningJP2016/dlgqn-111725780
Variational Autoencoder
①入力から潜在変数 ? への推論
② ? から入力の再構成(生成)
をニューラルネットでモデル化した深層生成モデル
尤度 ? ? を最大化したいが、直接は測れないので、
以下の変分下限を最大化するように学習する
E ?(?|?) log
? ? ? ? ?
? ? ?
= E ?(?|?) log ? ? ? ? DKL[? ? ? ||? ? ]
6
? ?? Generation
?(?|?)
Inference
?(?|?)
Variational Autoencoder
直感的な理解の仕方
1. 潜在変数 ? を平均 0, 分散 1の正規分布からサンプリングしてニューラルネット
ワークに入れたら画像を生成するようにしたい
2. そこで中間表現を確率的な変数 ? にしたAutoencoderを用意して、画像の再構
成を学習させることにする
– 再構成誤差はdecoder (generator) の出力をベルヌーイ分布の平均値と考え、その尤度を
計算する
– ただし、カラー画像の場合は分散を固定した正規分布として尤度を計算することが多い
(GQNもそう)
3. それだけだと ? になんの制約もないので、encoder (inference)が出力する ? が
平均 0, 分散 1に近づくようにKL項を用意して制約をかける
? 理論的には変分ベイズが背景にあるので、詳しく知りたい人は以下の元論文とレ
ビュー論文を読むことをオススメします
– Auto-Encoding Variational Bayes https://arxiv.org/abs/1312.6114
– Tutorial on Variational Autoencoders https://arxiv.org/abs/1606.05908
7
Conditional VAE
? 通常のVAEでは事前分布 ?(?) を平均 0, 分散 1 で固定していたが、それをラベル ? で
条件付けた ?(?|?) とし、ニューラルネットでモデル化することで、ラベルで条件付けた
生成を可能にしたもの
e.g. 1の数字が書かれたMNIST画像を生成するなど
? この場合、最大化したいのは条件付き尤度 ? ?|? になるので、変分下限も以下のよう
に変わる
E ?(?|?,?) log
? ? ?, ? ? ?|?
? ? ?, ?
= E ?(?|?,?) log ? ? ?, ? ? DKL[? ? ?, ? ||? ?|? ]
8
?, ? ??, ? Generation
?(?|?, ?)
Inference
?(?|?, ?)
Prior
?(?|?) ?
KL
Conditional VAE
基本的にはGQNもこれがベースです
条件付けるラベル (?) が与えられた観測と生成する画像の視点になっただけ
– 与えられた観測 (context): M個の視点 ??
1..?
と観測 ??
1..?
のペア
– 生成する画像の視点 (query): ??
?
これらで条件付けて、queryに対応する観測画像 ??
?
を生成する
つまり、GQNは ?(? ?|? ?..?, ? ?..?, ? ?) をモデル化している
? 確率モデル的な枠組みはこれだけですが、GQNではアーキテクチャの面で
様々な工夫を行なっている
– DRAWもその1つ
9
DRAW
? Gregorらが提案したVAEと自己回帰モデルを組み合わせたモデル
? 確率モデル的にはVAEと同じ
? VAEにおける潜在変数 ? への推論をRNNを用いて複数回に分けて自己回
帰的に行うことで、モデルの表現力を高めている
? ? ? =
?=1
?
??(??|?, ?<?)
? 厳密にはDRAWを提案した論文では他にもAttentionを用いるなど様々な工
夫をしているが、一般的には VAEにおいて ? への推論を自己回帰的に行
う枠組みを総称してDRAWと呼ばれている(少なくともGQN系の論文ではそ
ういう扱い)
10
DRAW
変分下限
log ? ? ≥ E ? ? ? log ? ? ? ? DKL[? ? ? | ? ?
= E ? ? ? log ?(?|?) ? DKL[
?=1
?
?? ?? ?, ?<? ||
?=1
?
?? ?? ?<? ]
= E ? ? ? log ? ? ? ?
?=1
?
E ? ?<? ? [DKL ?? ?? ?, ?<? ?? ?? ?<?
? E ? ? ? log ? ? ? ?
?=1
?
DKL[??(??|?, ?<?)||??(??|?<?)]
? 分布の積のKLは各分布のKLの和でモンテカルロ近似できる
11
DRAW
アルゴリズム
for ? = 1 to ?
Prior Distribution ?? ?? ?<? = ? ? ?? ? ?
Encoder RNN ? ? = ??? ??? ?, ??, ? ?, ? ?
Posterior Sample ?? ~ ?? ?? ?, ?<? = ? ? ?? ? ?
Decoder RNN ? ? = ??? ??? ??, ? ?
KL Divergence DKL[?? ?? ?, ?<? ||?? ??|?<? ]
Canvas ?? = ?? + ? ?
? ?
Likelihood ?(?| ? ?)
? RNNを通して自己回帰的に潜在変数 ? を推論する
? 画像はそれぞれの ? から生成したものを重ね書きしていく
(絵を描くのに似ているのでDRAWという名前が付いている) 12
GQN
contextとqueryで条件付けたConditional DRAW
? これまでの内容を踏まえると、GQNの変分下限は以下のようになる
E ? ? ? ?, ?1..?, ?1..?, ? ? log ? ? ?
?, ?1..?
, ?1..?
, ? ?
?
?=1
?
DKL[?? ?? ? ?
, ?1..?
, ?1..?
, ? ?
||??(??|?1..?
, ?1..?
, ? ?
)]
? さらにGQNではcontextを圧縮するrepresentation networkを用意して、? =
?=1
?
?(? ?, ? ?) としているため、以下のように簡潔化できる
E ? ? ? ?, ?, ? ? log ? ? ?
?, ?, ? ?
?
?=1
?
DKL[?? ?? ? ?
, ?, ? ?
||??(??|?, ? ?
)]
これがGQNの目的関数です
13
GQN
GQNのアルゴリズム
for ? = 1 to ?
Prior Distribution ?? ?? ?<? = ? ? ?? ? ?
Encoder RNN ? ? = ??? ??? ? ?, ? ?, ?, ??, ? ?, ? ?
Posterior Sample ?? ~ ?? ?? ? ?, ? ?, ?, ?<? = ? ? ?? ? ?
Decoder RNN ? ? = ??? ??? ? ?, ?, ??, ? ?
KL Divergence DKL[?? ?? ?, ?<? ||?? ??|?<? ]
Canvas ?? = ?? + Δ ? ?
Likelihood ?(? ?
|? ?
(? ?))
? 基本的にはDRAWに条件付ける変数 (? ?, ?) が加わっただけ
? 最後に尤度をとる前にもう一度NNをかませているのが少し違うが、それ以外は
普通のConditional DRAW 14
アーキテクチャ的な工夫点
? Representation networkによって、context情報を圧縮して環境に関する事
前知識を決定論的な変数 ? として学習させている
– Neural Processes (Garnelo et al., 2018) でも提案されている手法
?? = ?(? ?, ? ?)
? =
?=1
?
??
– それぞれのcontext情報を合計している点が大きな特徴
– これはそれぞれのcontextが順序不変 (permutation invariant) であることを利用して
いる (e.g. 得られる観測は見る視点の順番が変わっても変わらない)
– もしこれをRNNにしてしまうと、後に得られた観測ほどcontextに大きな影響を与える
ことになってしまう
– ただし、合計することが最善なのかは議論の余地がある
15
(余談) メタ学習としてのGQN
GQNの問題設定は実はメタ学習と全く同じ
– 任意のデータセット(ここではシーン)が与えられた時に入力(視点)から出力(観測)
へ正しく写像できるように、すべてのデータセットで共有されるパラメータをメタ知識と
して学習させる
– GQNではrepresentation networkがメタ知識を学習するパラメータを保持している
? NNによるメタ学習の主なアプローチは以下の2通り
① メタ知識をパラメータの初期値として学習させる (MAML)
② メタ知識を学習するためのネットワークを別に用意する
– GQN (Neural Processes) は②に当たる
? permutation invariantな問題設定では、GQNのようなメタ学習的なアプロー
チが有効であることを示したのも大きな貢献(世の中には同じようなアプ
ローチで解ける問題がたくさん眠っているはず)
16
アウトライン
1. Generative Query Networkとは
– 概要
– Variational Autoencoder
– DRAW
– GQNのアーキテクチャ的な工夫点
– (余談) メタ学習としてのGQN
2. 実装紹介
– Pixyzの紹介
– Representation Network
– Convolutional LSTM
– Generation Network
– その他
17
過去のGQN実装
TensorFlow
https://github.com/ogroth/tf-gqn
– 一番早く上がった実装
– 死ぬほどわかりづらい
– どうすればこんなにわかりづらく書けるのかわからない
– ちゃんと動くっぽい
Chainer
https://github.com/musyoku/chainer-gqn
– musyokuさんの実装
– わかりやすい
– データセットを自作するコードも公開(さすがの実装力)
– 結果のクオリティが高い
PyTorch
https://github.com/wohlert/generative-query-network-pytorch
– めちゃくちゃわかりやすい
– 学習が全くうまくいっていない
→ PyTorchだけまともに動くものがない
18
Pixyz
弊研?鈴木雅大作のPyTorchベースの深層生成モデル用ライブラリ
? ネットワークを確率モデルで隠蔽する構造になっているため、尤度やKL
Divergenceなど、確率分布間の操作をネットワーク構造を意識することなく
実装できるため、コードの可読性が上がる。
? TensorFlow Probability (Edward) などはネットワークと確率モデルを並列に
扱っているため、深層生成モデル用としてはやや使いづらい
19
実装
? 今回はPyTorch, Pixyzの両方で実装しました。
? 基本的にはほぼ同じですが、両者を比較しながらPixyzの使い方も紹介して
いきたいと思います。
注:現在のPixyzはLoss APIが自己回帰モデルに対応していないため、現状では
Pixyzを使うメリットがあまりありません。近々、自己回帰に対応したバー
ジョンがリリースされるので、そのタイミングでPixyz版のアップデートを
行う予定です。
? 論文内では一部のハイパーパラメータに関する記述がなかったため、著者
のEslamiさんに確認をとりました。
(おそらく元のPyTorch実装が動かないのはここのハイパラが間違っている
から)
20
論文に記述のないハイパーパラメータ
? 基本的に変数の次元数に関する記述が全くないため著者に確認
– 潜在変数 ? : batch_size x 3 x 16 x 16
– RNNの隠れ変数 ?, ? : batch_size x 128 x 16 x 16
– DRAWのcanvasとなる ? : batch_size x 128 x 64 x 64
– 潜在変数 z の次元数は学習に大きな影響を与えるようなので注意
(チャンネル数を大きくしすぎると全く学習が進みません)
? RNNへの入力のサイズを合わせるためのdown-sample, up-sampleのネット
ワーク
– すべてbias項なしの畳み込み or 逆畳み込みを1層用意し、サイズがRNNの隠れ変数
と同じになるようにする(channel数は変えない)
e.g. 画像は、kernel_size: 4x4, stride: 4x4, padding: 0 の畳み込みをしてから
RNNに入力する 21
dataset/convert2torch.py
? DeepMindが公開しているデータセット
https://github.com/deepmind/gqn-datasets
? 元データセットはTensorFlow用のものなので、PyTorch用に変換する
? Shepard-Metzlerは約1日で変換できました
? To do: マルチプロセス化
22
gqn_dataset.py
データローダ周りのスクリプト
sample_batch
– 1つのデータには1つのシーンに相当する画像と視点のペア群が入っているため、学
習時にはそこからcontextに使うものとqueryに使うものをランダムにサンプリングして
使用する
23
representation.py
Representation networkの実装
24
論文で提案されている3種類を実装
Poolが一番いいらしい
conv_lstm.py
Coreで使うConvolutional LSTMの実装
25
Core
EncoderRNN, DecoderRNNに当たる部分
26
入力のサイズを合わせてConvLSTMに入れる
Distributionクラス (Pixyz)
27
ネットワークを確率分布で隠蔽するクラス
ここでは正規分布なので、平均 (loc) と標準
偏差 (scale) を辞書形式で返す
初期化の際に、条件づける変数 (cond_var) と
変数 (var) をリストで渡す。
forwardの入力とcond_varを揃えることに注意
model.py
28
PyTorch Pixyz
*_coreが自己回帰の部分を担うConvolutional LSTM
Pixyzではeta_* の代わりに笔谤颈辞谤などの辫颈虫测锄.诲颈蝉迟谤颈产耻迟颈辞苍蝉クラスのインスタンスを立てる
model.py
29
Pixyzではネットワークを
確率モデルで隠蔽している
ため、q.sampleなどとする
だけで分布からのサンプリ
ングが可能で、コードが読
みやすくなる!
PyTorch Pixyz
train.py
30
負の変分下限をロス関数として学習する。
学習率と生成画像の分散のアニーリングはここで行
う。
TensorBoardXを用いてログを保存
train.py
オプション
? デフォルトでは、すべて論文通りの実装が動くようになっていますが、GPU
のメモリが多くない場合のために、train.pyに以下のオプションがあります
1. --layers
自己回帰のループ数 (default: 12)
8くらいでも十分な結果が得られます
2. --shared_core
自己回帰のループ間で同じRNNを使うかどうか (default: False)
これがパラメータをめちゃくちゃ増やしていますが、Trueでもほぼ同じ結果が得られます
(論文内でもFalseだと少しだけ結果が良くなると書いてある)
31
画像生成モデルの実装上の注意点
1. 最後の画像の出力を [0, 1] に抑える
– 初めて画像生成系の実装をするときにやりがちなミス
– 学習はできてるっぽいのに、生成画像に変なノイズが乗っているときは大抵これをし忘れて
いる
– モデルの出力は最後がsigmoid等になっていない限り [0, 1] になるとは限らないので、<0 or
>1 の部分を0, 1に書き換える必要がある
e.g. -0.1は実際に画像として表示させるときは0.9扱いになってしまうので、0に置き
換える
– PyTorchならtorch.clampでできる
2. 生成画像の分布を分散固定の正規分布にしている場合は分布の平均を出力
するようにする
– カラー画像の場合はだいたい該当する
– 分散は学習の対象ではないため、分布からサンプリングするとノイズが乗ってしまう
– Pixyzであれば、sampleではなく、sample_meanを使う
32
結果 (Shepard-Metzler)
33
GroundTruth Prediction
2週間ほど回し続けた結果(71.5万ステップ)
ほぼ見分けがつかないレベルで生成できるようになった
論文では200万ステップ回している(こちらのリソースでは1ヶ月くらいかかる…)
結果 (Mazes)
34
GroundTruth Prediction
学習時間約2日
得られた知見
? 論文通りの実験を行うにはマシンパワーが
相当必要
– TiTAN X (12GB)4枚でギリギリ乗るくらい
– 著者はK80 (24GB)4枚で実験したらしい
? 潜在変数 ? の次元数を大きくしすぎると、
学習初期にKLを過剰に近づけるように学習
が進んでしまい、再構成が全く学習されな
い
– 表現力を高めるためにむやみに ? の次元を増
やすと逆効果
? 論文通りまでパラメータを増やさなくても学
習は十分できる
– 自己回帰のループを減らすなど
35
失敗するケース
zのチャンネル数64
自己回帰8回で1日回した結果
Link
PyTorch
https://github.com/iShohei220/torch-gqn
Pixyz
https://github.com/masa-su/pixyzoo/tree/master/GQN
36

More Related Content

PyTorch, PixyzによるGenerative Query Networkの実装