欢迎光临
我们一直在努力

怎样实现并评估防御蒸馏(Defensive Distillation)的防御效果?

深入理解防御蒸馏(Defensive Distillation)

防御蒸馏(Defensive Distillation, DD)是一种旨在提高深度学习模型对对抗性攻击(Adversarial Attacks)鲁棒性的技术。它由Papernot等人在2016年提出,核心思想是利用知识蒸馏(Knowledge Distillation)的方法,通过平滑模型的决策边界来减少对抗性扰动的有效性。

在模型部署环境中,即使是很小的、人类无法察觉的扰动,也可能导致模型输出错误的结果。DD通过训练一个“学生”模型去模仿“教师”模型在高温(High Temperature, T > 1)下生成的软标签(Soft Labels),从而使模型的梯度幅度变小,决策表面更加平滑,攻击者难以找到有效的扰动方向。

实验环境准备

我们将使用PyTorch和MNIST数据集,并引入torchattacks库来生成对抗样本。

pip install torch torchvision torchattacks

Step 1: 建立基线模型(Standard Training)

首先,定义一个简单的卷积神经网络(CNN)模型,并用标准方法训练作为基线。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchattacks import FGSM

# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. 定义模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(5 * 5 * 32, 10) # 假设输入是28x28

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = nn.MaxPool2d(2)(x)
        x = nn.Flatten()(x)
        x = self.fc(x)
        return x

def load_mnist_data():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transform),
        batch_size=1000, shuffle=False)
    return train_loader, test_loader

# 训练函数(标准)
def train_standard(model, data_loader, epochs=5):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    print("\n--- Standard Model Training ---")
    for epoch in range(epochs):
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1} done.")
    return model

train_loader, test_loader = load_mnist_data()
standard_model = SimpleCNN()
# standard_model = train_standard(standard_model, train_loader) # 假设已经训练并保存
# torch.save(standard_model.state_dict(), 'standard_model.pth')

Step 2: 实现防御蒸馏

防御蒸馏分为两步:教师模型训练和学生模型训练(蒸馏)。关键在于使用高温$T$来平滑Softmax输出。

$$\text{Softmax}_{T}(z_i) = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}$$

我们将温度$T$设置为$20$。

A. 教师模型训练

教师模型使用高温$T$下的Softmax进行训练,但其损失函数仍基于硬标签(即标准的Cross-Entropy Loss)。

TEMPERATURE = 20.0

# 训练函数(教师模型 - 使用高温LogSoftmax)
def train_teacher(model, data_loader, epochs=5):
    model.to(device)
    # 使用LogSoftmax是为了数值稳定性,结合NLLLoss等价于CrossEntropyLoss
    criterion = lambda output, target: nn.NLLLoss()(nn.LogSoftmax(dim=1)(output / TEMPERATURE), target)
    optimizer = optim.Adam(model.parameters())
    print("\n--- Teacher Model Training (High T) ---")
    # ... (训练循环与标准训练相似,省略细节)
    return model

teacher_model = SimpleCNN()
# teacher_model = train_teacher(teacher_model, train_loader) # 假设已经训练并保存
# torch.save(teacher_model.state_dict(), 'teacher_model.pth')

B. 提取软标签并训练学生模型(蒸馏)

学生模型使用教师模型的软标签作为训练目标。我们使用KL散度(Kullback-Leibler Divergence)作为损失函数,来衡量学生模型的输出分布与教师模型软标签分布的相似度。学生模型的最终输出在推理时必须恢复到$T=1$。

$$L_{distill} = ext{KL}(P_{Teacher}(T) || P_{Student}(T))$$

# 2. 蒸馏训练函数 (学生模型)
def train_student_distillation(student_model, teacher_model, data_loader, epochs=5):
    student_model.to(device)
    teacher_model.to(device)
    teacher_model.eval() # 教师模型锁定,只用于生成软标签

    # 目标损失:KL散度 (学生模型的log_softmax和教师模型的softmax)
    criterion = nn.KLDivLoss(reduction='batchmean')
    optimizer = optim.Adam(student_model.parameters())

    print("\n--- Student Model Distillation (Defensive DD) ---")
    for epoch in range(epochs):
        for data, _ in data_loader:
            data = data.to(device)

            # 1. 教师模型生成软标签(注意:需要使用标准Softmax,非Log)
            with torch.no_grad():
                teacher_output = teacher_model(data)
                soft_labels = nn.Softmax(dim=1)(teacher_output / TEMPERATURE)

            # 2. 学生模型输出 (使用高温LogSoftmax)
            student_output = student_model(data)
            log_probs = nn.LogSoftmax(dim=1)(student_output / TEMPERATURE)

            # 3. 计算蒸馏损失
            # 注意:KLDivLoss期望第一个输入是Log概率,第二个是概率
            loss = criterion(log_probs, soft_labels) * (TEMPERATURE**2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Distill Epoch {epoch+1} done.")
    return student_model

# 假设teacher_model已加载,开始蒸馏训练
distilled_model = SimpleCNN()
# distilled_model = train_student_distillation(distilled_model, teacher_model, train_loader)
# torch.save(distilled_model.state_dict(), 'distilled_model.pth')

Step 3: 评估防御效果(Adversarial Evaluation)

防御蒸馏的效果必须通过对抗性攻击的成功率来衡量。我们将使用快速梯度符号法(FGSM)进行攻击,并比较基线模型和蒸馏模型在被攻击后的准确率。

关键指标: 被攻击准确率 (Adversarial Accuracy)。

# 加载预训练或已训练的模型状态
# standard_model.load_state_dict(torch.load('standard_model.pth'))
# distilled_model.load_state_dict(torch.load('distilled_model.pth'))

def evaluate_robustness(model, test_loader, attack_name, eps):
    model.eval()
    correct = 0
    total = 0

    # 初始化攻击
    if attack_name == 'FGSM':
        attacker = FGSM(model, eps=eps) # 注意: torchattacks会自动处理模型的设备和训练模式
    else:
        raise NotImplementedError("Attack not implemented")

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)

        # 生成对抗样本
        adv_data = attacker(data, target)

        # 评估模型对对抗样本的预测
        output = model(adv_data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        total += len(data)

    adv_accuracy = 100. * correct / total
    print(f"Robustness ({attack_name}, eps={eps}): {adv_accuracy:.2f}%")
    return adv_accuracy

EPSILON = 0.3 # 设定扰动强度

# 运行评估
print("\n=== Robustness Evaluation ===")
# evaluate_robustness(standard_model, test_loader, 'FGSM', EPSILON)
# evaluate_robustness(distilled_model, test_loader, 'FGSM', EPSILON)

# 预期结果示例 (在实际训练后)
# Standard Model Robustness (FGSM, eps=0.3): 5.12%
# Distilled Model Robustness (FGSM, eps=0.3): 65.45%

结果分析与总结

通过实际操作,我们会观察到:

  1. 干净准确率(Clean Accuracy): 蒸馏模型(DD)和基线模型在未受攻击的数据集上的准确率通常非常接近(都高于98%)。
  2. 对抗准确率(Adversarial Accuracy): 在中等强度的FGSM攻击下(如$\epsilon=0.3$),基线模型的准确率会急剧下降到个位数,而防御蒸馏模型能维持相对较高的准确率(通常高于50%甚至更高)。

这证明防御蒸馏有效地平滑了决策边界,使得对抗性扰动产生的梯度影响被稀释,从而显著提高了模型的对抗鲁棒性。然而,需要注意的是,随着更复杂的攻击方法(如C&W或PGA)的出现,纯粹的防御蒸馏可能不足以提供完全的保护,但它仍然是构建鲁棒性模型的重要基石。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 怎样实现并评估防御蒸馏(Defensive Distillation)的防御效果?
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址