TabNetを頑張って調べて見たりする遊び(2/2)

実装を少し端折る

前回に引き続き、今回は実装視点でみてみる
Kaggleにいい感じにシンプルな実装をみつけたので、それを参考にします。
https://www.kaggle.com/marcusgawronsky/tabnet-in-tensorflow-2-0
後述のとおり、若干掛けている部分があるので、そこだけピックアップします

その前に重要な機能を前回よりも細かく列挙

FeatureTransformer

  • Skip構造
    • ResNetなどに代表されるスキップ構造を導入
    • 最後(square(0.5))は結果の変動を抑止(学習安定化)するための安全策らしい。理由は不明
  • ShareとDecisionステップ
    • 構造的な違い
      • ShareStep:使用するFCは全て共通
      • Decision Step:使用するFCは全て異なる
        • スキップ可能
      • 参考サイトではShareStepとDecisionは1つのみ、Googleの実装も1,2つずつくらいで決まっていない疑惑。なので可変変数にしてみた
    • GhostBatchNormalization
      • ここでのBNは全てこれで、小分けにBNする手法
      • 解説記事1解説記事2TensorFlowでは既にオプションに組み込まれており,virtual_batch_sizeで小分けサイズを指定可
      • 元論文でもどうしてうまくいくか分からないらしいが、小分けにBNすることで大バッチにおける汎化誤差軽減を目論む
      • 今回テーブルdataという大量のdataに対して、「大きなバッチサイズ」で摘要すると書いてあるので、それに対応するためと推測
  • GLU
    • Gated Linier Unit;線形ゲート付きユニット
    • 中身は非常に単純で下記式はA=Bで、入力をシグモイドに変換し、元の入力と積をとる
      • これもアテンション機構+決定木の特徴量選定の一つだと推測
      • 結局入力信号の篩を掛けているだけだと思います。でもこのブログで取り上げられている通り、かますことで性能が上がるらしいです。

AttentiveTransformer

  • Sparsemax
    • softmaxよりもとんがった分類を行うらしく、論文ではマルチ分類に有効らしい
    • ablation testではRELUよりもよかったぞ
    • 決定木はいわゆる「ぶった切っ」て「結果を2分」するものなので、より鋭い活性化が必要だったと推測
  • Prior
    • 各ステップごとに係数を決定される。[過去のPrior*(gamma-Attentionからのスカラー値)]として流用される
      • これがBoostingの前回の反省を生かす仕組みの実装なのかは不明
        • これに関する実験記述なし
      • このgammaは全ステップ共通。論文中ではスパース性への対応力が変化すると記載あり。

Split

  • Relu行きor Attention行きかで若干ことなる
    • Relu行き:次元数n_d
    • Attention行き:次元数n_a
    • →つまり今更だけれど、FeatureTransの段階でn_d+n_aに変換する必要が実はある。最初の層は無理矢理変えても問題なさそう
    • →実はこの部分はGoogleAIの実装ではスキップされている。。。

上記をまとめて、(1)shre stepとdecision steps数可変化と(2)n_d, n_aの設定部分が不足しているので、その部分だけ切り抜いて実装した

もちろん、stepのClassをつくるけれども、その際にもコレらの変数は何らかの方法で受け渡すので
その実装は必要。

事前学習機構

今回のおもしろい部分。FeatureTranとAttentionをコアとするEncoderは、単純に入力と正解与えて学習することも可能だが、AutoEncoderの仕組みを構築して事前学習することで更に精度向上するとか、、、

Unsupervised pre-traininng中の表中「?」に対して、下段で学習機が当てに行っているように見えるけど、
仕組みとしては下記だと推測する

  • ?の部分は各ステップのマスキングを意味する。よって各ステップ内で精錬される。
    • 結局各ステップでマスキングが行われるため、毎回ことなる入力と出力になる
    • 学習時もエントロピでの評価

Decoderの仕組みも、Encoderと同一。逆につないだだけ。

一旦ここでおしまい

実装はこちらです。前述のとおり,kagleのコードを流用しているのでdata取得にはkaggleコマンドを通過させるようにする必要ありです

所感としては、なんだろう。内容としてはいつもとは違うモヤモヤが残る気がします
DNNだと結構コジツケ感があっても、まぁいいやというふうになるが
GBDTと似てます!っていうと決定木やBoostingの構造があるように見えるが、その周辺解説が少ないように見える。まぁシッカリ読んでいないからかもしれない

次回もゆるく何かを纏める

補足:すでにまとまっているニホンゴ資料