計算機視覺中的知識蒸餾
作者:ppog@知乎(已授權轉載)
編輯:CV技術指南
原文:https://zhuanlan.zhihu.com/p/497067556
前段時間熬完畢設的工作,趁著空閑想寫一篇關于知識蒸餾的博客,這是本人讀研期間的一個研究方向,但這篇博客不會過于深入,內容大概簡短說說自己對于知識蒸餾的一些看法,大多數內容來源于四月份看到的兩篇paper。鄙人愚見,有不當之處歡迎批評!
文中涉及到的三篇論文
《Distilling the Knowledge in a Neural Network》
paper:arxiv.org/pdf/1503.0253
code:github.com/labmlai/anno
《Solving ImageNet: a Unified Scheme for Training any Backbone to Top Results》
paper:arxiv.org/pdf/2204.0347
code:github.com/Alibaba-MIIL
《Decoupled Knowledge Distillation》
paper:arxiv.org/abs/2203.0867
code:github.com/megvii-resea
1、知識鋪墊one hot 編碼
one-hot 編碼(one-hot encoding)類似于虛擬變量(dummy variables),是一種將分類變量轉換為幾個二進制列的方法,即一種硬編碼形式,類似非黑即白。其中 1 代表某個輸入樣本屬于該類別。
深度學習領域中,通常將數據標注為hard label,但事實上同一個數據包含不同類別的信息,直接標注為hard label無法顯示圖像數據間的相關性,例如分類任務中,數據樣本(下圖)的hard label是【sheep:1】,而實際上,樣本中包含了一條狗,對應的soft label可能是【sheep:0.90;dog:0.10】。
基于上述事實:
hard label會根據照片,告訴我們這就是羊,其他都不是;
soft label會告訴我們,這張照片大概率是羊,存在一定概率是狗。
但在實際應用中,兩者均有其所長:hard label雖然更容易標注,但是會丟失類內、類間的關聯。而soft label能給模型帶來更強的泛化,攜帶更多的信息,但是獲取難度會比hard label大。
總的來說,兩者都屬于知識遷移的一種,知識蒸餾是模型層面的遷移方式,而遷移學習是數據層面的遷移方式。
具體而言,兩個在一定程度下都可以實現漲點,以ImageNet-1K、ImageNet-21K、ResNet18、ResNet31為例(假設驗證集恒不變):
對于遷移學習,我們使用ResNet18在ImageNet-21K上進行預訓練,訓練完后將模型遷移到ImageNet-1K上微調,在驗證集不變的情況,精度會更高。
對于知識蒸餾,我們使用ResNet32作為Teacher模型在ImageNet-1K上進行訓練,ResNet18作為Student模型同樣也在ImageNet-1K上訓練,但會引入訓練完后的Teacher模型做監督,往往精度也會提高。
但兩種方式都會帶來一些問題,例如訓練周期更長,更大的計算開銷,更嚴重的資源占用等等。
《Distilling the Knowledge in a Neural Network》是知識蒸餾的開山鼻祖,于2015年提出,目前引用量快超過10k。其提出來的帶溫度的kl散度損失是最早的分類算法蒸餾方案,由于是基于logits的蒸餾方式,易于復現,后續也有許多在KL散度上進行改進的版本。
Knowledge Distillation 的整體示意如上圖所示(基于logits):
Teacher model:結構較為復雜,特征提取能力更強的大模型,如ResNet31
Student model:結構較為簡單,特征提取能力一般的小模型,如ResNet18
Hard label:輸入數據所對應的類別,上文開頭解釋過了,常規的訓練一般都是使用的Hard label
Soft label:輸入數據通過Teacher模型softmax層的輸出,蒸餾訓練附加的loss基于此得來
distill loss:蒸餾采用的損失可能是KL、MSE、CE等,該論文采用的是基于溫度T的KL Loss
KD常見步驟
圍繞這幾個基本點,共進行步驟如下(假設數據集為cifer):
① ResNet31在cifer數據上訓練得到的教師模型
② 將教師模型的prediction軟化,即輸入數據通過teacher model所得到的softmax層的輸出:
③ 得到軟化的預測向量后,通過KL散度損失進行下一步計算:
可以看下代碼實現:
# y_s: student output logits
# y_t: teacher output logits
# T: temperature for KD
# teacher model: resnet31
# student model: resnet18
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
KLDLoss = nn.KLDivLoss(reduction="none")
p_s = F.log_softmax(y_s/self.T, dim=1)
p_t = F.log_softmax(y_t/self.T, dim=1)
loss = KLDLoss(p_s, p_t) * (self.T**2)
return loss
④ 在計算出蒸餾的loss后,將這個kl_loss附加在原始的分類損失(假設是CE loss)上:
在經過知識蒸餾的操作后,模型精度得到了提升,但當時開展的相關實驗比較少,畢竟是在2014年,各方面條件都有所限制,且文中作者也沒有十分詳細地解釋蒸餾能讓模型提升的具體原因。
時間來到了2022年,在4月7日,阿里達摩院在arvix上掛上了《Solving ImageNet》,該論文主要針對目前的計算機視覺模型,提出通用的訓練方案USI,并且該方案主要基于KD蒸餾的訓練方式。
不過在我看來,該論文展示了許多豐富的實驗及結果,并且驗證和解釋了為何KD是有效的,更像是對14年提出的KD進行詳盡的補充。
文中提到,目前的計算機視覺模型大致下可以分為四類:
類似ResNet的常規CNN模型(ResNet-like)
面向移動端的輕量模型(Mobile-oriented)
Transformer模型(Transformer-base)
僅包含MLP的模型(MLP-only)
該作者對上述四種架構的計算機視覺模型抽樣進行了實驗,有意思的是,使用基于KD方式的訓練方案的模型在Top-1上均獲得了不同程度的提高,特別是Mobile-oriented類的輕量模型。
為了更深入地了解KD對模型結果的影響,作者在下圖中展示了一些教師模型預測的標簽,與ImageNet真實標簽的對比。
圖片(a)包含了大量明顯的釘子,教師模型的預測是99.9%,而第二和第三個預測也與釘子(螺絲和錘子)相關,但概率值可以忽略不計。
圖片(b)中包含了一架客機,教師模型的最高預測是客機(83.6%)。然而,教師模型也有一些不能忽視的概率(11.3%)。這并非是錯誤,因為飛機上有機翼。這里的教師模型減輕了實際情況與真實標簽相互排斥的情況(即要么是1,要么是0),并提供了關于圖像內容更準確的信息(打個比方,前面提到的一張圖基本都是羊,但有一條狗,數據集的分類標簽是羊,但teacher教師預測時會留出部分概率給了狗)。
圖片(c)中包含了一只母雞。然而,母雞的信息并非很明顯,教師模型的預測反映了這一點,通過識別出一只概率較低的母雞(55.5%),還給出了一定的概率給公雞( 大約8.9%.)。雖然這是教師模型的錯誤,但實際上就算是人,這么小的目標似乎也很難一下子分得清。
在圖片(d)中,教師模型認為真實標簽是錯誤的。真實標簽是冰棍,而教師模型預測概率最大的是狗。作者認為教師模型的預測反而是對的,因為狗在圖片中的信息更為突出。
從上面的例子中可以看到,教師模型的預測比簡單( 0或1)的真實標簽包含了更豐富的信息,soft label解釋了類別之間的相關性。不僅如此,KD更能代表增強過后圖像的正確信息,能更好處理strong augmentations的問題。由于上述提到的原因,與僅使用hard label的訓練相比,使用教師模型的soft label進行訓練會提供有更有效的監督,訓練會變得更有效、更穩健。
上邊講到,KD有作用,但究竟是哪部分起作用,作用多大,是否存在負優化,值得思考!
在今年的3月16日,曠視對KD(KL Loss)進行了更加深入的剖析,提出了解耦蒸餾(《Decoupled Knowledge Distillation》,DKD),這篇文章很精彩,對14年提出的KD(KL Loss)進行了多方位的解析,也開展了許多實驗。
如上圖所示,研究者將 logits 拆解成兩部分,藍色部分指目標類別(target class)的 score,綠色部分指非目標類別(Non-target class)的 score。并且將KD重新表述為兩部分的加權和,即 TCKD 和 NCKD。
上述定義和數學關系將幫助我們得到 KL Loss 的新表達形式:
對于公式的補充解釋:
更有說服力的實驗
為了觀察TCKD 和 NCKD 對蒸餾性能的影響,作者做了大量實驗,并試圖通過實驗剖析TCKD 和 NCKD 的作用。
上圖為TCKD 和 NCKD在CIFAR-100 上進行的實驗,作者初步得出以下結論:
同時使用 TCKD + NCKD = KD 的蒸餾方式,Student模型均獲得不同程度的提升;
單獨使用 TCKD 進行蒸餾,會對蒸餾效果產生較大的損害,原因在于高溫系數(T)會導致損失附加上很大的梯度,增加非目標類的 logits ,這會損害學生預測的正確性;
單獨使用 NCKD 進行蒸餾,和 KD 效果差不多;
基于上述結論,是否 NCKD 更加有效,而 TCKD 存在負優化?作者給出了進一步的探討。
作者認為 TCKD 受限于數據集的難易程度,假設一個樣本經過教師模型后輸出概率是0.99,說明這個樣本是易樣本,數據集是容易分辨的,而當概率只有0.75,甚至是0.55,那么樣本會陷入到模棱兩可的狀態,模型也沒有把握認定它就是所謂的那個它(你那么愛它,為什么不把它留下),數據集難度增加。
作者補充了以下三個實驗:更重的數據增強;更多的噪聲;更復雜的數據。
1、更重的數據增強
上表顯示Teacher模型為ResNet32×4,Student模型為ShuffleNet-V1和ResNet8×4的實驗結果,在使用 AutoAugment數據增強方法的情況下,訓練集難樣本系數增大,此時使用 TCKD 可以達到較大的提升。
2、更多的噪聲
而通過引入噪聲,當噪聲比例增大,TCKD 的提升程度也加強。
3、更復雜的數據
使用ResNet34作為Teacher模型,ResNet18作為Student模型,作者發現學生模型的Top-1增加了0.32個點。
最后,作者給出的結論是,通過嘗試各種策略來增加訓練數據的復雜度(例如重的數據增強、更多的噪聲、困難的任務)來證明 TCKD 的有效性。結果證實,在對更具挑戰性的訓練數據進行知識蒸餾時,訓練樣本“復雜度(難度)”的提升對于 TCKD 可能更有增益,說明 TCKD 對于數據集中復雜任務的監督能力更強。
而上上上部分,作者也證實了NCKD 能力出眾,這也反映了一個事實:說明非目標類之間的知識對logits的蒸餾方式至關重要,它們可以比喻為能力出眾的“暗部成員”(知道卡卡西嗎?),論文中稱之為“暗知識”(dark knowledge)。
如何理解?大家可以把目標類別的logits看作是light knowledge,按照我們慣有的思維,目標類別是最重要的,我想要識別出一條狗,那么我就會找一大堆關于該目標類別的樣本,不斷填充和豐富它的logits信息,而非目標類別則顯得不那么重要,因為我們想要kill的名單中沒有他們,但不可置否,dark knowledge對于模型泛化性也非常關鍵。
依據 Teacher 模型預測的置信度,作者對cifer訓練集上的樣本做了排序,根據排序結果對數據集進行切分,置信度0.5-1為一塊,置信度為0-0.5為一塊,實驗結果如下:
在前 50% 的樣本上使用 NCKD 可以獲得更好的性能,這表明預測良好的樣本所攜帶的知識比其他樣本更豐富。然而,預測良好的樣本的損失權重被教師的高置信度所抑制。這也說明了,置信度高的樣本對蒸餾的效果更加顯著,應當采取措施讓它們不被抑制。
分類任務
作者使用DKD和KD進行對比,效果都要優于KD(KL Loss)的方式,在不同模型上實現了1-2,甚至是3個點的提升。
并且,作者對一些細節也進行了補充,通常a設置為1時效果較好,而實際應用中變動較大的為Beta,當具體調為何值,需要根據實際的業務數據進行實驗。
檢測任務
作者使用了Faster rcnn作為baseline,通過替換不同的backbone以此作為teacher和student,可以看出,DKD的方式帶來的提升均超過了原始KD的方式,而將DKD與基于Feature蒸餾結合起來組成的DKD+ReviewDKD提升更大。這也證明了,檢測任務十分依賴于feature的定位能力,而logits這種high level的信息并不具備這種能力,這也使得基于logits的蒸餾方式效果差于feature的蒸餾,但總的來說,KD的解耦型DKD還是展示了更加優越的性能。
這篇博客從三個層面講述了KD是什么?為什么有效?突然想寫這篇博客,原因在于四月份看到的兩篇論文解答了我之前在這個方向上的不少疑惑,隨整理出來。但由于本人并未涉略過深,仍會有很多理解不足的地方,也歡迎各位大佬批評指正!
參考文獻[1] pprp:知識蒸餾綜述:代碼整理
[2] medium.com/analytics-vi
[3] 從標簽平滑和知識蒸餾理解Soft Label
[4] [論文閱讀]知識蒸餾(Distilling the Knowledge in a Neural Network)
[5] Distilling the Knowledge in a Neural Network 論文筆記
[6] oldsummer:2021 《Knowledge Distillation: A Survey》
[7] CVPR 2022|解耦知識蒸餾!曠視提出DKD:讓Hinton在7年前提出的方法重回SOTA行列!
[8] 阿里巴巴提出USI 讓AI煉丹自動化了,訓練任何Backbone無需超參配置,實現大一統!
本文僅做學術分享,如有侵權,請聯系刪文。
*博客內容為網友個人發布,僅代表博主個人觀點,如有侵權請聯系工作人員刪除。







