
[Tabular] TabNet : Attentive Interpretable Tabular Learning

Alex_Rose 2021. 6. 2. 11:13


 

TabNet library GitHub link : https://github.com/dreamquark-ai/tabnet

 

Abstract

TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more efficient learning as the learning capacity is used for the most salient features.

 

 

Keywords

- interpretability

- self-supervised learning

- single deep learning architecture (for feature selection and reasoning)

- instance-wise feature selection

- soft feature selection with controllable sparsity in end-to-end learning (via sequential attention)

 

 

Architecture

TabNet is first trained with unsupervised pre-training and then fine-tuned with supervised learning.

 

Self-supervised tabular learning (unsupervised pre-training via masked self-supervised learning results in an improved encoder model for the downstream supervised task)

 

 

The part of the input that is masked in the DNN block
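To make the pretext task concrete, here is a toy sketch of masked-feature pre-training (my own illustration, not the pytorch-tabnet code; the encoder/decoder call is only a placeholder): a binary mask hides some feature cells, the encoder sees only the unmasked values, and the model is trained to reconstruct the hidden cells.

```python
import torch

# Toy tensors; "decoder(encoder(...))" would be the TabNet encoder-decoder pair in a real model.
batch_size, n_features = 256, 16
features = torch.randn(batch_size, n_features)

# Binary mask S: 1 = this cell is hidden from the encoder and must be reconstructed.
# The masking probability (0.2 here) is a hyper-parameter.
S = torch.bernoulli(torch.full((batch_size, n_features), 0.2))

encoder_input = (1 - S) * features            # the encoder only sees the unmasked cells

# Placeholder for decoder(encoder(encoder_input)); a real model predicts the hidden cells.
reconstruction = torch.zeros_like(features)

# The reconstruction loss is computed only on the masked cells.
loss = ((S * (reconstruction - features)) ** 2).mean()
print(loss.item())
```

In the paper, the per-feature reconstruction error is additionally normalized by the standard deviation of that feature.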

 

 

 

 

Architecture of each component of TabNet

 

Internal structure of the Feature transformer and the Attentive transformer shown in the architecture above

GLU : Gated Linear Unit
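From Fig. 4 of the paper, each feature transformer is a stack of FC → BN → GLU blocks (some shared across decision steps, some step-specific), with residual connections scaled by √0.5 so the variance stays roughly stable. A rough sketch of such a stack, assuming plain BatchNorm1d in place of ghost batch normalization and made-up layer sizes:

```python
import torch
from torch import nn
import torch.nn.functional as F

class GLUBlock(nn.Module):
    """One FC -> BN -> GLU sub-block of the feature transformer (illustrative sizes)."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.fc = nn.Linear(in_dim, 2 * out_dim)   # doubled width; GLU halves it again
        self.bn = nn.BatchNorm1d(2 * out_dim)      # the paper uses ghost batch norm here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.glu(self.bn(self.fc(x)), dim=-1)  # GLU = a * sigmoid(b), explained below

class FeatureTransformerSketch(nn.Module):
    """Two stacked GLU blocks with the sqrt(0.5)-scaled residual connection."""
    def __init__(self, dim: int):
        super().__init__()
        self.block1 = GLUBlock(dim, dim)
        self.block2 = GLUBlock(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.block1(x)
        return (self.block2(h) + h) * (0.5 ** 0.5)  # residual scaled by sqrt(0.5)

x = torch.randn(32, 16)
print(FeatureTransformerSketch(16)(x).shape)        # torch.Size([32, 16])
```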

 

 

The TabNet decoder architecture

 

 

GLU explanation : https://paperswithcode.com/method/glu#

As the figure on the right also shows, the input data is split in half along a dimension to form "a" and "b".

 

ⓧ : element-wise product

σ (small sigma) : sigmoid
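A quick sketch of what the GLU computes (this matches PyTorch's built-in torch.nn.functional.glu; the tensor shapes are just for illustration):

```python
import torch
import torch.nn.functional as F

def glu(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gated Linear Unit: split x in half along `dim` into (a, b), return a ⓧ sigmoid(b)."""
    a, b = torch.chunk(x, 2, dim=dim)
    return a * torch.sigmoid(b)

x = torch.randn(4, 8)
print(glu(x).shape)                              # torch.Size([4, 4]) - the width is halved
print(torch.allclose(glu(x), F.glu(x, dim=-1)))  # True: same result as the built-in
```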

 

 

 

 

Sparsemax normalization & Ghost batch normalization (GBN)

Sparsemax Normalization

 

We use an attentive transformer (Fig. 4 in the paper) to obtain the masks using the processed features from the preceding step, a[i-1]:

M[i] = sparsemax( P[i-1] · h_i( a[i-1] ) )

where h_i is a trainable function (an FC layer followed by BN) and P[i-1] is the prior scale term, P[i] = Π_{j≤i} (γ - M[j]).

 

 

Sparsemax normalization (Martins and Astudillo 2016) encourages sparsity by mapping the Euclidean projection onto the probabilistic simplex, which is observed to be superior in performance and aligned with the goal of sparse feature selection for explainability.

 

The sparsemax formulation originates from the sparsemax paper (reference 2).
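Below is a minimal sparsemax sketch following the algorithm in the sparsemax paper (my own, not the pytorch-tabnet implementation), plus a toy version of the mask equation above. `h_i` is stood in by a bare linear layer (the paper uses FC + BN), and γ is the relaxation parameter that controls how often a feature may be reused across decision steps.

```python
import torch

def sparsemax(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax over the last dim: Euclidean projection onto the probability simplex.
    Unlike softmax, it can assign exact zeros (Martins & Astudillo, 2016)."""
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.size(-1) + 1, device=z.device, dtype=z.dtype)
    z_cumsum = z_sorted.cumsum(dim=-1)
    support = 1 + k * z_sorted > z_cumsum                        # candidate support set
    k_z = support.sum(dim=-1, keepdim=True)                      # size of the support
    tau = (z_cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)   # threshold tau(z)
    return torch.clamp(z - tau, min=0.0)

logits = torch.tensor([[1.5, 1.0, 0.2, -0.5]])
print(sparsemax(logits))            # tensor([[0.7500, 0.2500, 0.0000, 0.0000]]) - exact zeros
print(torch.softmax(logits, -1))    # softmax keeps every entry strictly positive

# Mask at decision step i: M[i] = sparsemax(P[i-1] * h_i(a[i-1]))
n_features, n_a = 10, 8
h_i = torch.nn.Linear(n_a, n_features)   # stand-in for the trainable h_i
a_prev = torch.randn(4, n_a)             # processed features from the previous step
P_prev = torch.ones(4, n_features)       # prior scale, all ones at the first step
gamma = 1.3                              # relaxation parameter

M = sparsemax(P_prev * h_i(a_prev))      # feature-selection mask, each row sums to 1
P = P_prev * (gamma - M)                 # prior update: P[i] = prod_{j<=i} (gamma - M[j])
```

The exact zeros produced by sparsemax are what make the per-step masks sparse and therefore interpretable.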

 

 

 

Ghost Batch Normalization (GBN) 

 

Why use GBN?

  - "generalization gap"이라고 부르는 현상, 즉 배치사이즈가 클 때, 본 적이 없는 데이터에 대해서 더 나쁘게 동작하는 것을 해결하기 위해서 나왔다. (Batch Normalization은 인풋 데이터를 0 mean, 1 std. dev.로 노말라이즈 함)

Reference 3 (pytorch-tabnet) contains GBN example code.
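Rather than copying that code verbatim, here is a minimal sketch of the same idea (virtual batch size and momentum are just illustrative values): split the large incoming batch into small "ghost" batches and apply batch normalization to each chunk separately during training.

```python
import torch
from torch import nn

class GhostBatchNorm1d(nn.Module):
    """Ghost Batch Normalization: run BatchNorm1d on small virtual batches
    instead of the full (large) batch, recovering the regularizing noise of small-batch BN."""
    def __init__(self, num_features: int, virtual_batch_size: int = 128, momentum: float = 0.01):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(num_features, momentum=momentum)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:                      # at eval time, use running statistics as usual
            return self.bn(x)
        n_chunks = max(1, x.size(0) // self.virtual_batch_size)
        chunks = x.chunk(n_chunks, dim=0)          # split the big batch into ghost batches
        return torch.cat([self.bn(c) for c in chunks], dim=0)

gbn = GhostBatchNorm1d(16, virtual_batch_size=256)
out = gbn(torch.randn(1024, 16))    # each 256-sample ghost batch is normalized separately
print(out.shape)                    # torch.Size([1024, 16])
```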

 

The epsilon in the normalization formula serves the same purpose as the epsilon added to avoid errors when taking a log (i.e., numerical stability).

               (In the referenced figure, the epsilon is not included inside the blue parentheses.)

 

The author of reference 5 checked whether GBN gives similar results to NoisyBatchNorm (a method that adds different noise to different slices of the batch), and reports that it indeed performed similarly.

 

 

 

Performance

 

FEATURE SELECTION - Performance on synthetic and real-world datasets

 

TabNet vs. other feature selection-based DNN models

compared by AUC across 6 synthetic datasets

+ TabNet's performance is close to global feature selection - it can figure out which features are globally important.

 

 

Results on the Forest Cover Type dataset (classification of forest cover type from cartographic variables)

 

Sarcos dataset (Regressing inverse dynamics of an anthropomorphic robot arm)

TabNet's performance is on par with the best model from (Tanno et al. 2018), which uses ~100x more parameters.

 

 

Interpretability in Real-world datasets

 

We first consider the simple task of mushroom edibility prediction. TabNet achieves 100% test accuracy on this dataset.

It is indeed known that "Odor" is the most discriminative feature - with "Odor" alone, a model can get > 98.5% test accuracy. Thus, a high feature importance is expected for it.

 

TabNet assigns an importance score ratio of 43% for it, while other methods like LIME (Ribeiro et al., 2016), Integrated Gradients (Sundararajan et al., 2017), and DeepLift (Shrikumar et al., 2017) assign less than 30%. 

 

 

 

SELF-SUPERVISED LEARNING

Table 7 shows that unsupervised pre-training significantly improves performance on the supervised classification task, especially in the regime where the unlabeled dataset is much larger than the labeled dataset. 

 

And as the graph below shows, the model converges much faster when unsupervised pre-training is used.

 


This concludes a brief summary of the paper. Further details will be added as I implement the model and go through the references.

 

 

References

1) TabNet: Attentive Interpretable Tabular Learning : https://arxiv.org/pdf/1908.07442v5.pdf

2) From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification : https://arxiv.org/pdf/1602.02068.pdf

3) pytorch-tabnet (TabNet implementation) : https://github.com/dreamquark-ai/tabnet

4) Implementing TabNet in PyTorch : https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279

5) Ghost Batch Normalization : https://medium.com/deeplearningmadeeasy/ghost-batchnorm-explained-e0fa9d651e03
