포스트

[논문 리뷰] DearKD: Data-Efficient Early Knowledge Distillation for Vision Transformers

[논문 리뷰] DearKD: Data-Efficient Early Knowledge Distillation for Vision Transformers

DearKD: Data-Efficient Early Knowledge Distillation for Vision Transformers
CVPR 2022
Xianing Chen, Qiong Cao, Yujie Zhong, Jing Zhang, Shenghua Gao, Dacheng Tao
ShanghaiTech University, JD Explore Academy, Meituan Inc., The University of Sydney, Shanghai Engineering Research Center of Intelligent Vision and Imaging, Shanghai Engineering Research Center of Energy Efficient and Custom AI IC
[paper]

1. Abstract & Introduction

fig1

기존 연구 문제점

  • 트랜스포머는 강력한 성능을 보이지만, 다량의 데이터를 필요로함
  • CNN의 강력한 inductive bias(locality, weight sharing으로 인해)를 트랜스포머에 적용하려는 연구
    • convolution을 트랜스포머에 삽입하는 방식은 트랜스포머 본질적인 특성을 파괴할 수 있음
    • DeiT는 CNN에서 트랜스포머로 지식을 증류허지만, 초기 트랜스포머 층이 inductive bias를 캡처하기 어려움

제안 방법

  • 2 stage 학습 프레임워크
    1. 트랜스포머의 ‘early’ layer와 학습의 ‘early’ 단계에서의 KD를 수행
      CNN의 로직과 중간 층 모두에서 증류하여, 트랜스포머의 early 층에 명시적인 학습 신호를 제공, 또한 Multi-Head Convolution-Attention(MHCA) 설계
    2. KD를 수행하지 않고, 트랜스포머가 자체적으로 inductive bias를 학습할 수 있도록하여 트랜스포머의 강력한 성능을 활용하도록 함
  • Deepinversion(주어진 모델의 특성을 활용하여, 새로운 이미지를 만들어내는 방식) 기반의 Boundary-Preserving intro-divengence loss를 도입하여, 데이터 없는 설정에서도 성능 향상

Related Works

Data-Free KD

  • 실제 데이터를 접근하지 않고, 복잡한 교사 모델로 부터 학생 모델을 학습하는 방법
    • GAN 기반 방법
    • Prior 기반 방법
  • 두 방법은 모두 mode collapse 문제를 가지고 있으며, 저자는 Deepinversion을 Boundary-Preserving intro-divengence loss를 도입하여 다양한 샘플을 생성

3. Data-efficient Early Knowledge Distillation

fig2

Inductive Biases Knowlede Distillation

  • 기존 연구는 CNN의 초기단계에서, 텍스쳐와 같은 local 패턴을 잘 포착할 수 있음을 밝힘
  • 따라서 초기 트랜스포머 층에 이러한 inductive bias를 명시적으로 제공하는 것이 데이터 효율성을 개선하는데 중요함
  • 후반 단계에선 이러한 가이드가 트랜스포머 자체의 표현 능력을 해칠 수 있기 때문에, 초기단계에서만 KD를 사용

3.1. DearKD: Stage 1

Multi-Head Convolutional-Attention (MHCA)

fo1 fo3

  • 기존 연구는 N개의 헤드와 3차원 이상의 상대적 위치 인코딩을 가진 MHSA가 Convolutional layer를 표현할 수 있음을 증명하였음
  • 저자가 제안하는 MHCA는 MHSA에서 상대적인 위치 인코딩을 추가한 것(\(v^{(h)}\)는 각 헤드에 대한 상대적 위치 인코딩, \(r_{ij}\)는 query i와 key j간의 상대적 위치를 나타내는 값)
  • MHCA는 MHSA와 달리 2 파트로 구성됨
    • content part는 non-local semantic 정보를 학습
    • position part는 local detail을 학습

Early Knowledge Distillation

fig3 fo4 fo5

  • 학생모델 트랜스포머: \(H_S \in \mathbb{R}^{l \times d}\), 선생모델 CNN: \(H_T \in \mathbb{R}^{h \times w \times c}\)
  • 형태가 다르기 때문에, 직접적으로 distill 할수 없음
  • \(H_S\)를 \(H_T\)로 reshape, Depth-wise convolution, LayerNorm, ReLU를 적용하는 aligner를 설계
  • CNN의 중간 층에서 트랜스포머르 KD를 적용한 첫번째 연구임
  • 교사 모델의 hard label을 이용하여, CE기반 로직 증류를 추가로 사용

fo6

  • 최종 로스는 학생 모델의 예측과 실제 레이블로 측정한 CE 로스와, teacher와의 logit loss 그리고 CNN의 중간층에서 학생 모델 트랜스포머로의 KD 로스
  • \(L_{CE}\)를 상당히 헷갈리게 사용하고 있다.

3.2. DearKD: Stage 2

Transformers Instrinsic Inductive Biases Learning

fig4 fo7

  • stage 1에서의 상대적인 위치 인코딩은 그대로 적용
  • non-local 표현을 형성하기 위한 reception filed를 확장하는 학습을 수행
  • 그림 4의 average attention distance란, 특정 픽셀에서 다른 픽셀까지의 평균적인 attention 거리로, 낮을수록, local정보에 집중하는 것을 의미함
  • 즉 stage 1의 후반부에선 수렴을 하다가, stage 2에서 다시 가파르게 증가함으로써, non-local 표현을 형성한다고 해석할 수 있음
  • 수식으로는 일반적인 CE loss를 적용

4. DF-DearKD: Training without Real Images

DeepInversion

fig5 fo8 fo9

  • DearKD의 데이터 효율성을 활용하기 위해, 실제 이미지가 없는 극단적인 설정에서 이를 탐구.
  • DF-DearKD는 DearKD에 추가적인 image generation component 있는 것
  • 학습된 CNN teacher가 있고, 무작위 입력값 x와, 레이블 y가 주어지면, DeepInversion은 8번식을 최적화하여 이미지를 합성함
  • \(L_{\text{diversity}}\)는 중복된 이미지를 피하기 위한 다양성 손실
  • \(R\_x\)는 비현실적인 이미지를 피하기 위핸 이미지 정규화 항으로, x의 총 분산과 L2 norm을 패널티로 부여하는 \(R\_{\text{prior}}\)과, 현재 배치의 평균과 분산을 학습된 통계로 패널티를 부여하는 \(R\_{\text{BN}}\)으로 구성

Boundary-preserving intra-divergence loss

fig6

  • 기존 방식은 임베딩 공간의 over clustering문제로 어려운 샘플과 모호한 샘플을 생성함
  • 이러한 문제를 해결하기 위해, easy positive 샘플이 잠재 공간에서 다른 샘플과 멀어지도록 하면서, 클래스 경계의 영향을 받지 않도록 하는 로스를 제안함

fo10

  • \(x\_a\)를 앵커 이미지라고 했을 때, easy positive sample \(x\_{ep}\)는 같은 클래스 중 잠재공간에서의 유클리드가 거리가 가장 가까운 샘플임
  • 같은 방식으로 같은 클래스 중 거리가 가장 먼 hard positive sample \(x\_{hp}\), 다른 클래스 중 가장 거리가 가까운 hard negative sample \(x\_{hn}\)을 구할 수 있음

fo11

  • 따라서 이 수식은, ep 샘플을 거리를 최대화하여, intra-calss diversity를 최대화하기 위한 로스식임

fo12 fo13

  • 하지만 너무 클래스 경계를 넘어가도록 하지 않기 위해서, 앵커와 hp사이의 거리 \(dist\_{ap}\)와 앵커와 hn사이의 거리 \(dist\_{an}\)이 마진 만큼의 차이를 가지도록 하기 위한 triplet 로스를 제안함
  • 최종적으로 \(\alpha \_{ep}\)는 50, \(\alpha \_{triplet}\)은 0.5를 사용함

5. Experiments

5.1. Implementation details

DearKD

  • Baseline: DeiT-Ti, DeiT-S, DeiT-B로, 각각의 헤드 수를 12, 12, 16으로 증가하여 convolution 능력을 높인 DearKD-Ti, DearKD-S, DearKD_B를 제안
  • Input & 구조: 224X224 입력 이미지를 16X16 패치로 임베딩 8개의 MHCA, 4개의 MHSA 블록을 통해 전파
  • Teacher: RegNetY-16GF
  • Epoch: 300(1stage 250, 2stage 50) with AdamW
  • EPoch2: 모델 뒤에 1000이 달린 경우 (DearKD-Ti-1000), 1stage 800, 2stage 200
  • Augmentation: Mixup, Cutmix, Random Erasing, Random Augmentation
  • GPU: A100 8대

DF-DearKD

  • 다중 해상도 최적화: 112X112 해상도로 2000회 반복 최적화, 224X224해상도로 2000회 반복 최적화
  • hyperparameter:
    • \[\alpha_{TV} = 1 \times 10^{-4}\]
    • \[\alpha_{L2} = 1 \times 10^{-5}\]
    • \[\alpha_{BN} = 5 \times 10^{-2}\]
    • \[\alpha_{ep} = 50\]
    • \(\alpha_{triplet} = 0.5\) 이긴한데.. 밑에 2개 제외하면 위 로스식에 명확히 명시된 곳이 없음. 특시 첫번째꺼는 어딘지 모르겠음

5.2. Ablation Stduy

t1

  • MHCA만 추가하였을 때도, locality의 도입으로 약간의 성능 개선이 있음
  • MHCA없이 \(\text{L}_{hidden}\)을 사용하면 CNN과 트랜스포머의 구조 차이로 성능이 감소함
  • 최종적으로 트랜스포머 자체의 inductive bias 학습 또한 중요함을 강조

5.3. Analysis of Data Efficiency

t2

  • 동일 데이터 양 기준, 베이스라인보다 큰 마진으로 성능 격차를 보이며, 데이터 양이 줄어들수록, 격차가 더 커짐
  • 하지만 실제로 DeiT에서 제안하는 방식과 비교하였을 때는 그리 큰 차이는 아님

5.4. Comparison with Full ImageNet

t3

  • 유사한 모델 크기를 가지는 CNN과 ViT기반 모델과 비교했을 때, 우수한 성능을 보이고 있으나, EfficientNet이 더 효율적이면서 유사한 성능을 보임

t4

  • Downstream task로 fine-tuning할 떄, 이 단계에서는 teacher없이도 우수한 성능을 보임. 좋은 일반화 능력을 가지고 있음

5.5. Performance of DF-DearKD

t5 t6

  • 제안 방법이 실제 이미지로 distill한 결과보다는 낮지만, scratch로 학습한 결과와는 유사하다고 주장. (DeiT-S는 유사하진 않음)
  • 생성 이미지만으로 유의미한 결과를 보였다고 생각하며, 기존 data-free 방식보다는 우수한 성능을 보임
  • LPIPS를 사용하여, 비교하였을 때, 다양성을 입증함

Review

  • CNN의 중간 층에서 트랜스포머에 KD를 적용한 최초의 연구로, 이를 위해 local 측면을 파악할 수 있도록 상대적인 위치 인코딩을 도입
  • 데이터 효율적인 측면은 우수하나, 성능 측면은 조금 아쉬움
  • 데이터-free 측면에서 효율적인 구조를 손보임. 논문을 리뷰한 현재 기준으로 봤을 때 Diffusion 모델의 발전으로 인해 여기서 사용한 DeepInversion 방식은 큰 의미는 없을 것 같음
  • 다양한 제약조건에 사용한 로스는 참고할만함
  • 오타나, 불명확한 표현이 많아서.. 생각보다 거슬렸음. 내가 쓸 논문도 리뷰어의 입장에서 보면 거슬릴 수 있으니, 논문 쓸 때 주의해야할 것 같음
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.