狠狠撸

狠狠撸Share a Scribd company logo
  	
Dropout Distillation	
Samuel Rota Bulò, Lorenzo Porzi , Peter Kontschieder 	
	
ICML2016読み会	
紹介者:佐野正太郎	
株式会社リクルートコミュニケーションズ
(C)Recruit Communications Co., Ltd.	
背景:Dropout学習	
?? ニューラルネットワークの過学習を抑制する手法
?? 学習ステップ毎にランダムに一部のユニットを落とす
?? 暗に多数のネットワークのアンサンブルモデルを学習している
?? [Srivastava et al., 2014]
1
学習対象のネットワーク	
 学習ステップ1	
 学習ステップ2	
???	
学習時
(C)Recruit Communications Co., Ltd.	
背景:Dropoutにおける予測計算	
2
Doropout学習時にはネットワーク構造がランダム	
 =>	
 予測時にどの構造を採用するか?
理想:全てのDropoutパターンでの予測計算の期待値をとる
Standard Dropout [Srivastava et al., 2014]
?? 予測時にはユニットを落とさない
?? 各ユニットの出力を (1 – dropout率) でスケールすることで実用的な精度が得られる
Monte-Carlo Dropout [Gal & Ghahramani, 2015]
?? 予測時に複数のDropoutパターンを試して平均をとる
?? 予測の計算コストが高い代わりにStandard Dropoutよりも良い精度が得られる
(C)Recruit Communications Co., Ltd.	
背景:Distillation	
3
Distilling the knowledge in Neural Network [Hinton et al., 2014]
?? distill = 蒸留する
?? 複数のネットワークや複雑なネットワークを単一の小さなモデルに圧縮する手法
蒸留モデル	
アンサンブルモデル
(C)Recruit Communications Co., Ltd.	
提案手法:Dropout Distillation	
概要
?? Dropout学習が暗に獲得しているアンサンブルモデルを圧縮/蒸留(Distillation)する
?? Dropout学習後モデルのMonte-Carlo予測を模倣する新しいモデルを学習する
利点
?? Standard Dropoutと同じ予測計算コストでStandard Dropoutよりも高い予測精度
?? 半教師あり学習への応用可能性:教師信号が欠損したデータをDistillationフェーズで活用できる
?? モデル圧縮への応用可能性:Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮できる
欠点
?? Distillationフェーズに余計な時間がかかる	
 
4
(C)Recruit Communications Co., Ltd.	
提案手法:Dropout Distillation	
5
Dropout	
学習済み	
モデル	
生徒モデル	
損失関数	
Dropout	
パターン
(C)Recruit Communications Co., Ltd.	
提案手法:Dropout Distillation	
6
教師モデル	
(Dropout学習済み)	
生徒モデル	
Distillationフェーズでは	
教師モデルの振る舞いを真似るよう	
生徒モデルを学習する	
通常のDropout学習で	
教師となるモデルを獲得
(C)Recruit Communications Co., Ltd.	
提案手法:Dropout Distillation	
7
Distillation用	
学習データ	
(教師信号無し)	
教師モデル	
(Dropout学習済み)	
生徒モデル	
生徒モデルの出力	
出力間の損失を	
埋めるように	
生徒モデルの	
パラメタを更新	
教師モデルの出力	
生徒モデルには	
ドロップアウトをかけない	
教師モデルにドロップアウトを	
かけながら出力データを生成
(C)Recruit Communications Co., Ltd.	
提案手法:Dropout Distillation	
8
Distillation用	
学習データ	
(教師信号無し)	
教師モデル	
(Dropout学習済み)	
生徒モデル	
生徒モデルの出力	
教師モデルの出力	
教師モデルと生徒モデルの	
ネットワーク構造は違っていてもよい	
データはDropoutフェーズから流用可	
新しいデータを用意するのも可
(C)Recruit Communications Co., Ltd.	
理想の予測関数
?? 全てのDropoutパターンでの出力期待値
?? Dropoutパターンはユニット数に対し指数関数的に増加するので事実上計算できない
問題設定
?? 『理想の予測関数』を教師モデルとした生徒モデルを学習したい
どうやって『理想の予測関数』を計算に取り入れるか?
Dropout学習済みモデル	
導出	
9
理想の予測関数	
損失関数	
生徒モデル	
評価できない	
Dropoutパターン
(C)Recruit Communications Co., Ltd.	
アプローチ
?? 『理想の予測関数』をDropout学習済みモデルで置き換える
?? 損失関数がBregmanダイバージェンスのとき以下の最小化問題が等価
Bregmanダイバージェンス
?? 二乗損失?Logistic損失?KLダイバージェンスなどを一般化したもの
Dropout	
学習済み	
モデル	
導出	
10
生徒モデル	
微分可能な凸関数	
Dropoutパターン	
この表現を形にしたのが	
スライド5?8のアルゴリズム
(C)Recruit Communications Co., Ltd.	
証明
qに関係ないので定数とみなせる	
導出	
11
本来の最小化対象	
Dropout Distillationでの最小化対象
(C)Recruit Communications Co., Ltd.	
実験1:予測計算手法による性能比較	
12
MNIST/CIFAR10/CIFAR100データセットで3予測手法のエラー率比較
?? Standard Dropout
?? Monte-Carlo Dropout(100サンプリング)
?? Dropout Distillation
実験手順
1.? Dropout学習でベースラインモデルを獲得(300エポック)
2.? ベースラインモデルでStandard DropoutとMonte-Carlo Dropoutの性能評価
3.? ベースラインモデルを教師としてDropout Distillation(30エポック)
–? 生徒モデルのネットワーク構造はベースラインモデルと同様
–? ベースラインモデルの学習後パラメタで生徒モデルを初期化
–? ベースラインモデルの入力データを流用(pixel毎に確率0.2で値をゼロ化)
4.? 生徒モデルでDropout Distillationの性能評価
(C)Recruit Communications Co., Ltd.	
実験1:予測計算手法による性能比較	
13
?? 平均エラー率は Standard > Distillation > Monte-Carlo の順
?? Monte-CarloよりDistillationの方がパフォーマンスの分散は低い
(C)Recruit Communications Co., Ltd.	
実験2:Distillationに使うデータセットによる性能比較	
14
Distillationフェーズの入力データについて3シナリオで性能比較
?? [Train] 教師モデルのトレーニングセットをそのまま利用
?? [Pert. Train] 教師モデルのトレーニングセットをピクセル毎に確率0.2で値をゼロ化
?? [Test] テストデータを利用
どのシナリオが	
優れているかは	
場合による
(C)Recruit Communications Co., Ltd.	
実験3:モデル圧縮への応用可能性	
15
CIFAR10/Quickでユニット数を削減した場合のパフォーマンス変化
?? [Baseline] Dropout学習のみで削減後モデルを学習
?? [Distillation] Dropoutフェーズで削減前モデルを学習してDistillationフェーズで削減後モデルに圧縮
青枠内では『Dropoutフェーズで複雑なモデルを学習 => Distillationフェーズで圧縮』が有効に働いている
FC層からのみユニットを削った場合	
 全層からフィルタ/ユニットを削った場合
(C)Recruit Communications Co., Ltd.	
従来手法
?? Standard Dropout:予測時間が短いけど精度が低め
?? Monte-Carlo Dropout:予測時間が長いけど精度が高め
提案手法の主な貢献
?? Standard Dropoutと同じオーダーの予測時間
?? 安定してStandard Dropoutよりも良い精度が出る
場合によって効いてくるメリット
?? 教師信号が欠損したデータをDistillationフェーズで活用
?? Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮
まとめ	
16

More Related Content

Dropout Distillation

  • 1.    Dropout Distillation Samuel Rota Bulò, Lorenzo Porzi , Peter Kontschieder ICML2016読み会 紹介者:佐野正太郎 株式会社リクルートコミュニケーションズ
  • 2. (C)Recruit Communications Co., Ltd. 背景:Dropout学習 ?? ニューラルネットワークの過学習を抑制する手法 ?? 学習ステップ毎にランダムに一部のユニットを落とす ?? 暗に多数のネットワークのアンサンブルモデルを学習している ?? [Srivastava et al., 2014] 1 学習対象のネットワーク 学習ステップ1 学習ステップ2 ??? 学習時
  • 3. (C)Recruit Communications Co., Ltd. 背景:Dropoutにおける予測計算 2 Doropout学習時にはネットワーク構造がランダム => 予測時にどの構造を採用するか? 理想:全てのDropoutパターンでの予測計算の期待値をとる Standard Dropout [Srivastava et al., 2014] ?? 予測時にはユニットを落とさない ?? 各ユニットの出力を (1 – dropout率) でスケールすることで実用的な精度が得られる Monte-Carlo Dropout [Gal & Ghahramani, 2015] ?? 予測時に複数のDropoutパターンを試して平均をとる ?? 予測の計算コストが高い代わりにStandard Dropoutよりも良い精度が得られる
  • 4. (C)Recruit Communications Co., Ltd. 背景:Distillation 3 Distilling the knowledge in Neural Network [Hinton et al., 2014] ?? distill = 蒸留する ?? 複数のネットワークや複雑なネットワークを単一の小さなモデルに圧縮する手法 蒸留モデル アンサンブルモデル
  • 5. (C)Recruit Communications Co., Ltd. 提案手法:Dropout Distillation 概要 ?? Dropout学習が暗に獲得しているアンサンブルモデルを圧縮/蒸留(Distillation)する ?? Dropout学習後モデルのMonte-Carlo予測を模倣する新しいモデルを学習する 利点 ?? Standard Dropoutと同じ予測計算コストでStandard Dropoutよりも高い予測精度 ?? 半教師あり学習への応用可能性:教師信号が欠損したデータをDistillationフェーズで活用できる ?? モデル圧縮への応用可能性:Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮できる 欠点 ?? Distillationフェーズに余計な時間がかかる 4
  • 6. (C)Recruit Communications Co., Ltd. 提案手法:Dropout Distillation 5 Dropout 学習済み モデル 生徒モデル 損失関数 Dropout パターン
  • 7. (C)Recruit Communications Co., Ltd. 提案手法:Dropout Distillation 6 教師モデル (Dropout学習済み) 生徒モデル Distillationフェーズでは 教師モデルの振る舞いを真似るよう 生徒モデルを学習する 通常のDropout学習で 教師となるモデルを獲得
  • 8. (C)Recruit Communications Co., Ltd. 提案手法:Dropout Distillation 7 Distillation用 学習データ (教師信号無し) 教師モデル (Dropout学習済み) 生徒モデル 生徒モデルの出力 出力間の損失を 埋めるように 生徒モデルの パラメタを更新 教師モデルの出力 生徒モデルには ドロップアウトをかけない 教師モデルにドロップアウトを かけながら出力データを生成
  • 9. (C)Recruit Communications Co., Ltd. 提案手法:Dropout Distillation 8 Distillation用 学習データ (教師信号無し) 教師モデル (Dropout学習済み) 生徒モデル 生徒モデルの出力 教師モデルの出力 教師モデルと生徒モデルの ネットワーク構造は違っていてもよい データはDropoutフェーズから流用可 新しいデータを用意するのも可
  • 10. (C)Recruit Communications Co., Ltd. 理想の予測関数 ?? 全てのDropoutパターンでの出力期待値 ?? Dropoutパターンはユニット数に対し指数関数的に増加するので事実上計算できない 問題設定 ?? 『理想の予測関数』を教師モデルとした生徒モデルを学習したい どうやって『理想の予測関数』を計算に取り入れるか? Dropout学習済みモデル 導出 9 理想の予測関数 損失関数 生徒モデル 評価できない Dropoutパターン
  • 11. (C)Recruit Communications Co., Ltd. アプローチ ?? 『理想の予測関数』をDropout学習済みモデルで置き換える ?? 損失関数がBregmanダイバージェンスのとき以下の最小化問題が等価 Bregmanダイバージェンス ?? 二乗損失?Logistic損失?KLダイバージェンスなどを一般化したもの Dropout 学習済み モデル 導出 10 生徒モデル 微分可能な凸関数 Dropoutパターン この表現を形にしたのが スライド5?8のアルゴリズム
  • 12. (C)Recruit Communications Co., Ltd. 証明 qに関係ないので定数とみなせる 導出 11 本来の最小化対象 Dropout Distillationでの最小化対象
  • 13. (C)Recruit Communications Co., Ltd. 実験1:予測計算手法による性能比較 12 MNIST/CIFAR10/CIFAR100データセットで3予測手法のエラー率比較 ?? Standard Dropout ?? Monte-Carlo Dropout(100サンプリング) ?? Dropout Distillation 実験手順 1.? Dropout学習でベースラインモデルを獲得(300エポック) 2.? ベースラインモデルでStandard DropoutとMonte-Carlo Dropoutの性能評価 3.? ベースラインモデルを教師としてDropout Distillation(30エポック) –? 生徒モデルのネットワーク構造はベースラインモデルと同様 –? ベースラインモデルの学習後パラメタで生徒モデルを初期化 –? ベースラインモデルの入力データを流用(pixel毎に確率0.2で値をゼロ化) 4.? 生徒モデルでDropout Distillationの性能評価
  • 14. (C)Recruit Communications Co., Ltd. 実験1:予測計算手法による性能比較 13 ?? 平均エラー率は Standard > Distillation > Monte-Carlo の順 ?? Monte-CarloよりDistillationの方がパフォーマンスの分散は低い
  • 15. (C)Recruit Communications Co., Ltd. 実験2:Distillationに使うデータセットによる性能比較 14 Distillationフェーズの入力データについて3シナリオで性能比較 ?? [Train] 教師モデルのトレーニングセットをそのまま利用 ?? [Pert. Train] 教師モデルのトレーニングセットをピクセル毎に確率0.2で値をゼロ化 ?? [Test] テストデータを利用 どのシナリオが 優れているかは 場合による
  • 16. (C)Recruit Communications Co., Ltd. 実験3:モデル圧縮への応用可能性 15 CIFAR10/Quickでユニット数を削減した場合のパフォーマンス変化 ?? [Baseline] Dropout学習のみで削減後モデルを学習 ?? [Distillation] Dropoutフェーズで削減前モデルを学習してDistillationフェーズで削減後モデルに圧縮 青枠内では『Dropoutフェーズで複雑なモデルを学習 => Distillationフェーズで圧縮』が有効に働いている FC層からのみユニットを削った場合 全層からフィルタ/ユニットを削った場合
  • 17. (C)Recruit Communications Co., Ltd. 従来手法 ?? Standard Dropout:予測時間が短いけど精度が低め ?? Monte-Carlo Dropout:予測時間が長いけど精度が高め 提案手法の主な貢献 ?? Standard Dropoutと同じオーダーの予測時間 ?? 安定してStandard Dropoutよりも良い精度が出る 場合によって効いてくるメリット ?? 教師信号が欠損したデータをDistillationフェーズで活用 ?? Dropuoutで複雑なモデルを学習してDistillationフェーズで圧縮 まとめ 16