MATLAB ユーザーコミュニティー

MATLAB & Simulink ユーザーコミュニティー向け日本語ブログ

深層距離学習 Deep Metric Learning でチワワ ? ポメラニアン?

今回は最近入社されたアプリケーションエンジニア福地さんの投稿です。
愛犬を紹介するために深層距離学習の ArcFace を実装するという前代未聞の自己紹介プレゼンを繰り出していただきました。これを社内にとどめておくのはもったいなさすぎる・・ということで無理言ってこちらに登場頂きました。
コードも Github: Implementation-ArcFace-in-MATLAB に公開されてますので、ぜひ試してみてください!

こんにちは。アプリケーションエンジニアリング部 自立システム・ロボティクス担当の福地です。
昨年 12 月に入社し、社内のみなさんに自己紹介をする機会がありました。ありがたいことに、その内容についてブログにしてみないかとお声がけいただいたのでその裏側をご紹介します!
自己紹介では ArcFace と呼ばれる、深層距離学習の手法を使ったお遊びをやってました。なぜそんなことをやったのかというと

何かにこじつけて飼っているわんこのかわいさを紹介したかったから

です!!!
思考フローはこんな感じです。
  1. 自己紹介何を話そうかな~、やっぱり覚えてもらうには自分の好きなものを熱く語ったほうが良さそう!
  2. 好きなものといえば犬!ついでうちの飼っているわんこを紹介しよう!
  3. でもいくらかわいいとはいえ単純に画像を貼り付けていると、飽きちゃうしそもそも自己紹介じゃなくて犬紹介になっちゃうな。。。
  4. そうだ!MATLABで何か実装してそのサンプルとしてうちの子を使おう!
  5. 実装するならここ10年くらいずっと流行りのDeep Learning かつ、 MATLAB公式実装では用意されてなくて自前実装が必要 だと会社の皆さんのウケが良さそう(あと、実装が比較的簡単にできればもっといい(笑))
と考えていく中でArcFaceにたどり着きました!
ArcFaceは
  1. 距離学習の一つであり、「未学習の人の顔の類似度の計算」や「異常検知」などができて魅せ方が色々考えられそう!
  2. 分類用のネットワークの最終出力を変更するだけで実装できて簡単 (MATLAB なら画像用に pretrain 済みの ResNet がある。最終層をいじって転移学習でつくれそう!)
ということで良さそうだな、と選びました。ArcFace を MATLAB 実装したコードは Github に公開しております!そのまま学習、テストできるのでぜひ一度動かしてみてください!
ちなみにこれがうちのかわいい愛犬です!「こまち」といいます!
はい、かわいい 紹介できてよかった
先に簡単に紹介しましたが、深層距離学習の手法の一つです。通常の分類問題に加えて論文中で “margin loss” と呼ばれるlossを導入することで、下の画像における、「クラス 1 の中心 $ W_1 $ とクラス 1 のサンプル $ x_{11} $ 間の距離 $ d_{12} $ がクラス 2 の中心 $ W_2 $ とクラス 1 のサンプル $ x_{11} $ 間の距離 $ d_{22} $ に対して設定したマージン m 以上の差が生まれる」ようにしています。したがって、学習された結果各サンプルは、分類されたクラス中心との距離がより近く他のクラス中心との距離はより遠くなるように特徴量を学習します。ArcFace ではこのクラス中心とサンプルの特徴を正規化した大きさ1の特徴空間(=超球面上)で表現することで、距離を cosine 類似度で表現できます。
そのために、クラス分類の最終層である softmax 層の前に、特徴空間の正規化+正解クラス中心へのマージン追加 を行う層を追加(または、全結合層の置換)します。逆に言えば、それだけでサンプル同士やクラス中心との距離の概念(=cosine 類似度)を導入できる優れモノです。このマージン m は、学習時に Loss が大きくなるよう、正解クラスを出力しづらくなるように Penalty を与えるような処理と解釈できます。正規化をしたままでは、同一クラスの特徴を近くに、異なるクラスの特徴を遠くに埋め込むための表面積が限られているので、最後パラメータ s でスケール倍して softmax に渡します。m, s が ArcFace のハイパーパラメータになります。
下記は通常の分類結果と ArcFace の分類結果を表示している論文中の図です。各色がそれぞれクラスを表しており、各点がそのクラスに分類されたサンプル、線はクラス中心を示してます。二次元なので円上(超球面が二次元では円)にクラスサンプルとクラス中心が存在してます。softmax のみの学習に対して ArcFace での学習結果のほうが各クラスがコンパクトにまとまっており、各サンプルについて、分類されたクラスのクラス中心との距離がより近く他のクラス中心との距離はより遠くなっていることがわかります。
突然ですが、みなさんは先ほど載せていた「こまち」の犬種が何かわかりますか?近所を散歩させていると、よく「柴犬?」とか「チワワ?」とか声をかけられます。みなさん自信が無さそうなのでどこか雰囲気の違いを感じ取られているのだと思います。正解は「ポメラニアン×柴犬の MIX」です!ポメラニアンの血が入っているのですが、不思議なことにポメラニアンか尋ねられることはあまりありません。おそらくポメラニアンはまだ飼われている数が多くないから(最近は増えてきてるらしいですが)、知っている犬種で推定するバイアスがかかっているのではないかと思います。令和2年全国犬猫実態調査によると柴犬、チワワの数は圧倒的ですね
そこで、今回は ArcFace の「サンプルとクラス中心の距離」が計算できる点を活用しようと思います。こまちの画像をサンプルとして入力し、「柴犬」クラスや「ポメラニアン」クラスとの距離= cosine 類似度を計算することで、AI 的にこまちがどの犬種に似ているのか、それぞれの犬種との類似度を出してもらおうと思います!
今回は Oxford pet dataset を使用して学習しました!37種類のペット画像がそれぞれ 200 枚程公開されており、しかも、顔部分にbounding box が振られているため、犬種の類似度を簡単に算出できそうです!「こまち」は柴犬、ポメラニアン、チワワに似ていると散歩中に言われましたが、その 3 犬種も入っているのでそれぞれとの距離を測れます!
学習をまわしてみたところ、いい感じです!公式ドキュメントに学習の一連の流れが載っておりちょこっと変更するだけでグラフの可視化が簡単にできるのが嬉しいですね。最終的に良さそうだと思って学習結果がこちらです。
上のグラフが Loss,下のグラフが 分類問題の正解率です。このグラフを見ながら順調に Loss が下がって正解率が上がるようにハイパーパラメータを調整してました!学習データで 98 %、検証データでも正解率が 94 %なので、お遊びの今回はこのくらいでよいかなとということで採用しました!(本当は「実験マネージャ」でハイパーパラメータのサーチを行ったり、データ拡張を真面目にやった方が良い結果はでるんでしょうけど(笑))
学習後の各クラスとの cosine 類似度を表示した結果がこちらです!正解ラベルの「newfoundland」に Penalty が付与されて類似度が減少しているのがわかります。そのおかげで学習時は Loss が大きめに出て、より正解との類似度が大きく、他の類似度を小さくするように学習が進んだんですね。おかげで Penalty なしの出力では、他クラスの出力と類似度で大きなギャップがあるのがわかります。
学習もそこそこうまくいったので、いよいよ「こまち」がどの犬種に似ているのか判定させてみたいと思います。期待としては、やっぱり「柴犬」と「ポメラニアン」の両方の類似度が高いことですが・・・!
ん・・・?
確かに似ていると言われているチワワや、柴犬、ポメラニアンの類似度が大きい結果が出たのですが、なぜチワワが一番高い?やはり似ているのか。。。しかし、画像違いでは柴犬一択のような結果になっている。写真の取り方や見える角度でどの犬種に似ているのかが変わる犬ってことなのか・・・何を以てしてそれぞれの判定をしているのかが気になるところです。GradCAMocclusionSensitivity とかを使えば、耳とか目とかどこの部位でそれぞれのクラス判定が起きているのかがわかりそうですね(チワワと判定されるのは、チワワとポメラニアンが両方まん丸黒目が似ているからな気がします!そこらへんがわかったらおもしろいですね!)

まあ一匹で3犬種分かわいいってことかな

今回は実装した ArcFace と愛犬の紹介をさせていただきました!ArcFaceでは、サンプルの特徴量とクラス中心との類似度を計算して「こまち」がどの犬種に似ているかを判定してました。人の顔分類について学習させてサンプル同士の特徴量の類似度を比較すれば、顔の類似度比較なんかもできるようです(というか、本当はそっちがメイン)。また、今回は手動で設定したハイパーパラメータ()を最適に設定する AdaCos というのも提案されているようです。日々新しい手法が提案されていてわくわくしますね!
今回私が作成したプログラム、データは丸ごと Github で公開しています。是非ご活用ください。
最後まで読んでくださり、ありがとうございました。

|
  • print

コメント

コメントを残すには、ここ をクリックして MathWorks アカウントにサインインするか新しい MathWorks アカウントを作成します。