TabNetを頑張って調べて見たりする遊び(1/2)

TabNetとの出会い

TabNetは年末にKaggleで出現した手法で、Tableデータに対して有効なDNN モデルだそうです。
画像・音声認識などには疎く、時系列などにフォーカスしている身としては、抑えておきたい内容

論文のContribution を纏めるとこのようにかいていますね。

  • 前処理なしで勾配降下法に基づく最適化によって学習可能
    • 後述でGBDTのようにできると記述あり
  • 各決定ステップで推論する特徴を選択するためにAttentionを使用。
  • 特徴選択と推論のための単一の深層学習アーキテクチャを採用。
  • さまざまなドメインの分類および回帰問題に関するデータセットにおいて、他の表計算モデルよりも優れているか同等。
  • 特徴重要度(interpretability)とそれらの組み合わせ方を視覚化する局所的解釈性、各特徴の学習モデルへの貢献度を定量的に示す大域的解釈性。
  • 教師なしの事前学習を用いてマスクされた特徴を予測することにより,大幅な性能向上を示した.
  • マスクされた特徴を予測するために,教師なしの事前学習を行うことで,表形式のデータに対して初めて大幅な性能向上

DNNド素人の自分には下記の点で意味不明でした

  • DNNで決定木?
  • しかも特徴量の重要度まで分かっちゃう?

すでに様々な記事でこの問題に関する記事は多く散見されるので、ここではいつもどおり備忘用に自分のメモを残します。

DNNが決定木に化ける

NNを決定木にしよう!という試みはここ最近活発のようです。
既にニホンゴの解説記事が丁寧に書かれていたので拝見。

TabNet中の引用論文は上記以外にも含まれるが、コンセプトとしては以下のテクで決定木らしい事ができると理解した

  • 1. 出力層に限定せす、ある層に注目する。その層は手前の全結合層からaffinされているとする
  • 2. (ある入力値を通過させた際に、注目した全結合層の値)*(重み)すれば、各出力の値が産出されて、結果的にsoftmaxされてノードが選択される。
  • 3. 決定木も、ある値を入力して枝をたどるときには、分岐条件にしたがってノードが選択される。
  • 4. あれ、2と3似てない? 重みが分岐ルールに相当するってことになる。

上記ルールだと、ある層のある親ノードと子ノードの問題なので、最終的には親となりうる全ノードに対して適用する必要があり、
算出された値からノードの選択方法として以下の様なことが考えられる。どちらがベターかはわからない。

  • Hard: 内積の値を全ノードで計算。全体でみて確定的に選択。
  • Soft: ソフトマックス値(Sum=1)を使用して子ノードの選択確率を伝搬し、確率的に選択

これだと全結合層を連結させたDNNは決定木のように扱うことが可能と理解してしまうが大丈夫であろうか?

この論文ではちょっとアプローチが違う気がするが、上記を頭に入れていざ論文。

論文中のGBDTになるぜ!の説明図がある。この図はモデルを簡略化したもの。

  • 入力をマスキングしてピックアップ
  • 全結合に掛けて結果を出す
  • 図は二個だけど、このセット数は任意に設定可能で、結果を全て足す。
    • 足した結果を行列式で見ると、重みによって決定境界が引かれているように見える
  • 実は。。
    • FCと次のMaskの間にはattentionn機構があって、前回の結果が共有されている。
    • マスク四角には後方向に掛けてPriorityという係数が設定されて、各セットの重みも制御している

一見すると、各セットで独立している様に見えるが、読まないと理解できない。
個人的には以下の点でGBDTだ!って言っていると解釈した

  • マスキングとセットの大量生成がGBDTのRandomForest的部分をクリア
  • FCと次のMaskの間にはattentionがBoosting部分をクリア
    • 出力「全て足す」の部分も同じ
  • Gradientの部分がわからない。
    • 出力「全て足す」の部分が微分可能で・・・ってのはそもそも線形結合だし違うのでは?
    • Priorityは計算上、前回の結果を反映していないのでGradientとも言えない
    • 論文最後に、これはスパースなdataの場合に有効。と一言。なぜかはわからない。

これしか説明できないんのだよね・・

ぐずってもアレなので実装を拝見する

先述の通り、過去にいくつもの実装があるため拝見しました。
しかし残念なことに、PyTorchでは論文にかなり近い実装あったのですが、
TensorFlowでは不十分な実装が多かったです。

大事な部分を列挙します。

全体

四角が先の図の「FC」に相当。実装するときは、下記の四角を1単位としたほうが都合が良い。以降黄色枠の見方で追う。今回の肝は下記2点

  • FeatureTransformerは多層のNN。ここで特徴量に基づく値が出力(決定木の木に相当)。
  • AttentionTransformerはFeatureTransformerの結果を踏まえて(過去の結果を踏まえて)、入力された特徴量の取捨選択を実施。

で、特徴量重要度って?

AttentionTransformerは一般的にいうAttentionと同義で、出力結果は注目すべき特徴量is何。
言い換えれば、これが特徴量の重要度とも言える。しかし、ここでの結果はあくまで「あるStep」の結果なので、局所的な特徴量重要度。

一方、冒頭にいった大域的重要度は図中の赤丸importanceを各Stepで集約・平均したものと定義。
dataが通過したあとの係数を平均することで、より一般的な重要度が算出できるということだと理解。

ここで思ったのが、おなじみの共変性の問題。
Globalの特徴量重要度は上記の文から、がっつり共変性ありそうな気がします。

分けて、実装かく