深入理解防御蒸馏(Defensive Distillation)
防御蒸馏(Defensive Distillation, DD)是一种旨在提高深度学习模型对对抗性攻击(Adversarial Attacks)鲁棒性的技术。它由Papernot等人在2016年提出,核心思想是利用知识蒸馏(Knowledge Distillation)的方法,通过平滑模型的决策边界来减少对抗性扰动的有效性。
在模型部署环境中,即使是很小的、人类无法察觉的扰动,也可能导致模型输出错误的结果。DD通过训练一个“学生”模型去模仿“教师”模型在高温(High Temperature, T > 1)下生成的软标签(Soft Labels),从而使模型的梯度幅度变小,决策表面更加平滑,攻击者难以找到有效的扰动方向。
实验环境准备
我们将使用PyTorch和MNIST数据集,并引入torchattacks库来生成对抗样本。
pip install torch torchvision torchattacks
Step 1: 建立基线模型(Standard Training)
首先,定义一个简单的卷积神经网络(CNN)模型,并用标准方法训练作为基线。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchattacks import FGSM
# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1. 定义模型
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.relu = nn.ReLU()
self.fc = nn.Linear(5 * 5 * 32, 10) # 假设输入是28x28
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = nn.MaxPool2d(2)(x)
x = nn.Flatten()(x)
x = self.fc(x)
return x
def load_mnist_data():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True, transform=transform),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transform),
batch_size=1000, shuffle=False)
return train_loader, test_loader
# 训练函数(标准)
def train_standard(model, data_loader, epochs=5):
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
print("\n--- Standard Model Training ---")
for epoch in range(epochs):
for data, target in data_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1} done.")
return model
train_loader, test_loader = load_mnist_data()
standard_model = SimpleCNN()
# standard_model = train_standard(standard_model, train_loader) # 假设已经训练并保存
# torch.save(standard_model.state_dict(), 'standard_model.pth')
Step 2: 实现防御蒸馏
防御蒸馏分为两步:教师模型训练和学生模型训练(蒸馏)。关键在于使用高温$T$来平滑Softmax输出。
$$\text{Softmax}_{T}(z_i) = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}$$
我们将温度$T$设置为$20$。
A. 教师模型训练
教师模型使用高温$T$下的Softmax进行训练,但其损失函数仍基于硬标签(即标准的Cross-Entropy Loss)。
TEMPERATURE = 20.0
# 训练函数(教师模型 - 使用高温LogSoftmax)
def train_teacher(model, data_loader, epochs=5):
model.to(device)
# 使用LogSoftmax是为了数值稳定性,结合NLLLoss等价于CrossEntropyLoss
criterion = lambda output, target: nn.NLLLoss()(nn.LogSoftmax(dim=1)(output / TEMPERATURE), target)
optimizer = optim.Adam(model.parameters())
print("\n--- Teacher Model Training (High T) ---")
# ... (训练循环与标准训练相似,省略细节)
return model
teacher_model = SimpleCNN()
# teacher_model = train_teacher(teacher_model, train_loader) # 假设已经训练并保存
# torch.save(teacher_model.state_dict(), 'teacher_model.pth')
B. 提取软标签并训练学生模型(蒸馏)
学生模型使用教师模型的软标签作为训练目标。我们使用KL散度(Kullback-Leibler Divergence)作为损失函数,来衡量学生模型的输出分布与教师模型软标签分布的相似度。学生模型的最终输出在推理时必须恢复到$T=1$。
$$L_{distill} = ext{KL}(P_{Teacher}(T) || P_{Student}(T))$$
# 2. 蒸馏训练函数 (学生模型)
def train_student_distillation(student_model, teacher_model, data_loader, epochs=5):
student_model.to(device)
teacher_model.to(device)
teacher_model.eval() # 教师模型锁定,只用于生成软标签
# 目标损失:KL散度 (学生模型的log_softmax和教师模型的softmax)
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.Adam(student_model.parameters())
print("\n--- Student Model Distillation (Defensive DD) ---")
for epoch in range(epochs):
for data, _ in data_loader:
data = data.to(device)
# 1. 教师模型生成软标签(注意:需要使用标准Softmax,非Log)
with torch.no_grad():
teacher_output = teacher_model(data)
soft_labels = nn.Softmax(dim=1)(teacher_output / TEMPERATURE)
# 2. 学生模型输出 (使用高温LogSoftmax)
student_output = student_model(data)
log_probs = nn.LogSoftmax(dim=1)(student_output / TEMPERATURE)
# 3. 计算蒸馏损失
# 注意:KLDivLoss期望第一个输入是Log概率,第二个是概率
loss = criterion(log_probs, soft_labels) * (TEMPERATURE**2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Distill Epoch {epoch+1} done.")
return student_model
# 假设teacher_model已加载,开始蒸馏训练
distilled_model = SimpleCNN()
# distilled_model = train_student_distillation(distilled_model, teacher_model, train_loader)
# torch.save(distilled_model.state_dict(), 'distilled_model.pth')
Step 3: 评估防御效果(Adversarial Evaluation)
防御蒸馏的效果必须通过对抗性攻击的成功率来衡量。我们将使用快速梯度符号法(FGSM)进行攻击,并比较基线模型和蒸馏模型在被攻击后的准确率。
关键指标: 被攻击准确率 (Adversarial Accuracy)。
# 加载预训练或已训练的模型状态
# standard_model.load_state_dict(torch.load('standard_model.pth'))
# distilled_model.load_state_dict(torch.load('distilled_model.pth'))
def evaluate_robustness(model, test_loader, attack_name, eps):
model.eval()
correct = 0
total = 0
# 初始化攻击
if attack_name == 'FGSM':
attacker = FGSM(model, eps=eps) # 注意: torchattacks会自动处理模型的设备和训练模式
else:
raise NotImplementedError("Attack not implemented")
for data, target in test_loader:
data, target = data.to(device), target.to(device)
# 生成对抗样本
adv_data = attacker(data, target)
# 评估模型对对抗样本的预测
output = model(adv_data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
total += len(data)
adv_accuracy = 100. * correct / total
print(f"Robustness ({attack_name}, eps={eps}): {adv_accuracy:.2f}%")
return adv_accuracy
EPSILON = 0.3 # 设定扰动强度
# 运行评估
print("\n=== Robustness Evaluation ===")
# evaluate_robustness(standard_model, test_loader, 'FGSM', EPSILON)
# evaluate_robustness(distilled_model, test_loader, 'FGSM', EPSILON)
# 预期结果示例 (在实际训练后)
# Standard Model Robustness (FGSM, eps=0.3): 5.12%
# Distilled Model Robustness (FGSM, eps=0.3): 65.45%
结果分析与总结
通过实际操作,我们会观察到:
- 干净准确率(Clean Accuracy): 蒸馏模型(DD)和基线模型在未受攻击的数据集上的准确率通常非常接近(都高于98%)。
- 对抗准确率(Adversarial Accuracy): 在中等强度的FGSM攻击下(如$\epsilon=0.3$),基线模型的准确率会急剧下降到个位数,而防御蒸馏模型能维持相对较高的准确率(通常高于50%甚至更高)。
这证明防御蒸馏有效地平滑了决策边界,使得对抗性扰动产生的梯度影响被稀释,从而显著提高了模型的对抗鲁棒性。然而,需要注意的是,随着更复杂的攻击方法(如C&W或PGA)的出现,纯粹的防御蒸馏可能不足以提供完全的保护,但它仍然是构建鲁棒性模型的重要基石。
汤不热吧