[논문 리뷰] VkD : Improving Knowledge Distillation using Orthogonal Projections
[논문 리뷰] VkD : Improving Knowledge Distillation using Orthogonal Projections
VkD : Improving Knowledge Distillation using Orthogonal Projections
CVPR 2024
Roy Miles, Ismail Elezi, Jiankang Deng
Huawei Noah’s Ark Lab
[paper]
1. Abstract & Introduction
기존 연구 문제점
- 전통적인 KD는 특정 작업 및 모달리티에만 제한적으로 적용 가능.
- 기존 특징 Distll 방법은 비싼 relational object와 메모리 뱅크에 의존하여 높은 계산 비용 발생.
- 대부분의 특징 KD 파이프라인이 휴리스틱(경험적 간단한 접근법) 디자인에 의존하여 새로운 통찰력을 제공하지 못함.
- Task별 추가 보조 로스가 KD 목표와 충돌하여 학생 모델의 성능 저하를 초래함.
제안 방법
- 새로운 직교 프로젝션 층을 제안하여 학생 백본에 증류되는 지식을 극대화(너무 강력한 프로젝션 층은 증류에 방해됨-일종의 shortcut제공 & 학습할 때만 사용함).
- 특징 유사성을 보존하여 프로젝션 층이 학생 표현을 변경하지 않도록 함.
- 작업별 정규화 단계를 도입하여 보조 로스를 KD 로스에 통합.
- 데이터가 제한된 이미지 생성에서 화이트닝 단계를 통해 특징 다양성을 장려.
2. Related work
Layer reparameterisation
- 가중치를 제약하여, favourable property를 구하기 위한 기술로 많이 사용됨
- 그 중 직교행렬은 cheap controllable 디퓨전 모델을 파인튜닝 가능하게함
3. Orthogonal Projection
3.1. Why use orthogonal projections?
- 목적: projection layer(P)가 feature extractor와 공유되지 않는 새로운 표현을 학습할 가능성을 완화하는 것 (projection layer는 학습 중에만 쓰기 때문) -> 프로젝션 된 피처를 매칭하기보다 feature extractor 자체를 teacher와 일치시키고자 함 -> 프로젝션을 통해 구조적 정보를 보존하는 방식을 제안
- 구조적 정보는 커널 행렬 K로 설명되며, 이 행렬은 배치 내 모든 특징 간의 쌍별 유사성을 캡처함, 이를 보존하는 것이 목표
- H는 힐베르트 공간(무한 차원 공간에서 정의되는 수학적 구조)으로, 두 벡터 간의 가장 일반화된 내적이라고 볼 수 있음
- 이러한 커널 함수는 테일러 급수 전개를 통해 다차원 내적의 합으로 근사할 수 있음. 따라서 구조적 정보를 보존하기 위해서는 내적을 보존하는 P가 필요함
- 두 샘플 간의 구조적 정보를 유지하기 위해서는 P의 전치와 P의 역행렬이 같아야 함을 확인할 수 있음
- 하지만 학생 차원 \(d_s\)와 선생 차원 \(d_t\)가 다르기 때문에, P([\(d_s, d_t\)])는 정사각형 형태가 아니며, 따라서 표준 역행렬이 존재하지 않음. (두 차원이 같을 경우 special orthogonal group \((SO(d_s_))\)로 표현)
- 구조적 정보를 유지하기 위해 오른쪽 역행렬을 선택(\(PP^T= I_{d_t}\) 이를 만족하도록 P를 변경하는 것을 뜻함)하여 P를 orthonormal rows 행렬((각 행이 서로 직교하고, 크기가 1인 행렬)로 구성함
- 오른쪽 역행렬을 선택하는 것은 복잡한 과정으로 이를 재매개변수화를 사용해 딥러닝으로 해결(왼쪽 역행렬 선택도 가능하지만, 재매개변수화 효율을 위해 오른쪽 역행렬을 선택)
- 이러한 orthonormal rows 행렬의 전치는 stiefel matrix manifold \(V_{d_{t}}(R^{d_{s}})\)로 표현할 수 있음, 이는 \(d_s\) 차원에서 \(d_t\)개의 서로 직교하는 벡터의 집합을 의미함
- 이러한 stiefel manifold는 형태가 smooth하여서 재매개변수화를 사용해 경사하강법과 같은 테크닉을 적용하기에 적합함
3.2. Orthogonal reparameterisation
- P의 직교성을 보장하기 위한 여러 방법 중 Matrix Exponential 방법을 사용
- 매개변수 W를 skew-symmetric matrix(\(W + W^T = 0\)을 만족하는 행렬)로 정의
- \(exp(W) · exp(W)^T = exp(W+W^T ) = exp(−W^T +W^T ) = exp(0) = I\) 이 성질을 봤을 때, exp(w)가 직교 행렬임을 확인할 수 있음.
- 이는 하나의 지수 행렬만 계산하면 되어서 효율성이 높으며, 지수는 Pad´e근사법을 활용하여 효율적인 구현이 가능함
- 따라서 이를 활용하여 \(d_t\) 차원의 \(SO_{d_{t}}\)를 구성하고, 마지막 \(d_t - d_s\) 행을 제거하면, \(V_{d_{t}}(R^{d_{s}})\)를 구성할 수 있음.
- 즉 매개변수 W를 통해 stiefel manifold로 표현되는 orthonormal를 구성할 수 있고, 이것은 학생 피처의 구조적 정보를 유지하면서 선생 차원으로 프로젝션할 수 있음
3.3. Orthogonal projections minimise redundancy
- 직교 프로젝션은 데이터의 거리 개념을 보존하고, 왜곡 없이 변환을 수행하여, 정보의 손실을 최소화함
3.4. Introducing domain-specific priors
- 보조 손실(task loss)를 사용하면, 모델 학습에 도움이 되지만, distill loss와 충돌하여, distill을 방해할 수 있음
- 이를 해결하기 위해 보조 손실 없이 implicit하게 도메인 지식을 kd 목표에 통합하는 일반화된 정규화 프레임워크를 제안함
Standardisation improves model convergence.
- 표준화는 입력의 무작위 변동에 대한 kd loss의 강건성을 높이는데 효과적
Whitening improves feature diversity.
- $X_{\text{whitened}} = V \Lambda^{-1/2} V^T X_{\text{centered}}$
- whitening은 데이터의 분표를 균일하게 만드는 과정으로(중심화된 데이터 생성 -> 공분산 행렬 계산, 고유값 분해), 서로 다른 특성 간의 상관관계를 제거함 -> generative 태스크에서 중요함
- 화이트닝 수행 후, L2 로스 적용
- \(Z^t\)는 화이트닝 되었기 때문에, \((Z^t)^T(Z^t) = I\)임
- const와 \(\lambda\)는 모델 파라미터와는 무관한 상수 >= 3
- C는 학생과 교사의 특징 간의 거리를 캡처하는 유클리드 교차 상관 행렬이 됨
- 결과적으로 화이트닝 제약 조건 하에서 L2 loss를 최소화 하는 것은, 교차 상관 행렬의 비대각 항을 최대화하도록 유도됨 -> 모든 툭징으 교사 모델에 대해 비상관적으로 유도 -> “decorrelation” 과정을 “feature diversity”의 증가 설명, 즉 화이트닝을 활용하여 특징 간의 독립성을 높이고, 이를 통해 다양한 특징을 효과적으로 학습하도록함
4. Experiments
Implementation details
- Nvidia v100 2대로 실험
- ImageNet 실험: -DeiT와 동일한 학습률과 하이퍼파라미터 사용
- Object detection 실험:
- ViDT 방법과 동일한 방식을 사용하되, 원래의 토큰 매칭 손실(token matching loss)을 𝑉_𝑘 𝐷로 대체
- 이미지 생성 작업:
- KDDLGAN 방법과 동일한 방식을 사용하되, 보조 다양성 손실(auxiliary diversity losses)을 제거하고, 대신 교사 표준화(teacher standardisation) 또는 whitening으로 대체
- 텍스트 인코더에서의 distillation을 제거하여 cost를 줄임
Data efficient training of transformers
- 동일 파라미터 기준 가장 우수한 성능을 보이며, 학습 epoch을 길게 가져갈 필요가 없음을 보임
- 또한 패치 토큰을 통한 직접적인 증류 방식이 유효함을 보임
- CNN -> Transformer
- 토큰 매칭 방식과 비교하였을 때, 우수한 성능
- 더 큰 교사 모델을 사용할 때에 최고의 성능 -> 교사와 용량의 차이가 큰 경우에도 성능이 제한되지 않음을 보여줌
- Transformer -> Transformer
Data limited image generation
- 명시적인 추가 로스 없이, 피처 whitening을 통해 다양한 이미지 생성이 가능함을 보임
- 또한 데이터가 제한된 상황에서 더 큰 향상이 있으며, 이는 학습 데이터가 부족할 때, 피처 다양성이 훨씬 중요함을 보임
- Transformer -> CNN
Ablation study
- MLP와 projector ensemble이 짧은 학습 epoch에서는 Linear 레이어보다 효과적
- 긴 epoch에서는 Linear 레이어가 효과적 -> 표현력이 강한 프로젝션이, 학생의 feature extractor와 공유되지 않는 새로운 표현을 학습하기 때문에 성능의 정체가 발생
- Orthogonal projection은 정확도가 높고, 수렴속도가 빠름
- 대체로 Orthogonal projection이 성능 향상의 원인
- 생성 작업에서는 정규화의 필요성이 대두됨
- CNN을 teacher로 Transformer로 증류를 할 때, CNN의 inductive bias가 Transformer에 효과적임을 보여줌
- 다른 KD 방식에 비해, 중요한 객체에 attention이 잘됨을 시각화
- CNN->TF, TF->TF, TF->CNN 어떤 방식에도 KD가 문제없이 적용됨을 보여주었음
Review
- 피처를 그대로 따라하고, header만 재사용하면 된다는 논문(Knowledge Distillation with the Reused Teacher Classifier CVPR 2022)과 일맥상통하는 부분이 있다고 보여지며, 좀 더 일반화된 방식이며, 이론이 뒷받침된것으로 생각됨
- 선생과 학생의 아키텍처에 무관하게 적용 가능하다는 것이 장점
- Data limited image generation(GAN 기반)은 이제 의미가 있는 연구인지는 잘 모르겠다
이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.