在AI模型部署到真实世界场景中时,一个关键的挑战是模型对于“自然损坏”(Natural Corruptions)的抵抗能力。这些损坏包括雾、雪、亮度变化、数字噪声等。ImageNet-C(ImageNet-Corrupted)基准是量化模型鲁棒性的黄金标准。本文将深入介绍如何设置环境,并使用PyTorch实现ImageNet-C的完整评估流程,计算出关键指标:平均损坏误差(mCE)。
Contents
1. ImageNet-C概览与核心指标
ImageNet-C包含15种类型的损坏,每种类型又分为5个不同的严重程度(Severity Levels,1到5)。这意味着我们需要评估模型在 $15 \times 5 = 75$ 个不同的子数据集上的性能。
核心指标:平均损坏误差 (mCE)
mCE衡量的是模型在所有损坏类型和严重程度上的平均性能下降。为了计算mCE,我们需要先计算每个子数据集上的相对损坏误差 (CE):
$$\text{CE} = \frac{E_{corruption} – E_{clean}}{E_{baseline} – E_{clean}}$$
其中:
* $E_{corruption}$ 是模型在特定损坏子集上的错误率。
* $E_{clean}$ 是模型在标准ImageNet验证集上的错误率(作为比较对象)。
* $E_{baseline}$ 是一个基准模型(通常是未优化的AlexNet)在标准ImageNet验证集上的错误率。
最终的mCE则是所有15种损坏类型、5个严重等级的CE的平均值。
2. 环境设置与数据准备
假设您已经部署了PyTorch环境。我们需要安装timm库来方便地加载预训练模型,以及torchvision进行数据处理。
2.1 数据结构
ImageNet-C的数据通常需要手动下载,并组织成如下结构:
1
2
3
4
5
6
7
8
9
10 /imagenet-c/
|-- fog/
| |-- 1/ (Severity 1)
| |-- 2/
| |-- ...
| |-- 5/
|-- gaussian_noise/
| |-- 1/
| |-- ...
|-- ... (15 types total)
每个叶子目录(如fog/1/)内包含标准的ImageNet验证集目录结构(例如n01440764/, n01443537/等)。
3. 实现鲁棒性评估脚本
以下是一个使用PyTorch实现ImageNet-C评估的核心流程。
3.1 必要的配置和工具函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
# --- Configuration ---
IMAGENET_C_ROOT = '/path/to/imagenet-c/'
BATCH_SIZE = 128
NUM_WORKERS = 4
CORRUPTIONS = [
'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression'
]
SEVERITIES = [1, 2, 3, 4, 5]
# ImageNet 预处理标准
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Helper function to calculate error rate
def get_error_rate(output, target):
_, pred = output.topk(1, 1, True, True)
return 1.0 - (pred.t() == target).float().mean().item()
3.2 核心评估函数
我们需要遍历所有 75 个子集,计算各自的错误率。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 def evaluate_corruption(model, data_root, corruption_name, severity):
# 1. 设置特定损坏子集的路径
corruption_dir = os.path.join(data_root, corruption_name, str(severity))
# 2. 检查路径是否有效
if not os.path.isdir(corruption_dir):
print(f"Warning: Path not found: {corruption_dir}")
return None
# 3. 加载数据集
dataset = datasets.ImageFolder(corruption_dir, transform=transform)
loader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True
)
model.eval()
total_error = 0.0
total_samples = 0
with torch.no_grad():
for input, target in loader:
input = input.cuda()
target = target.cuda()
output = model(input)
total_error += get_error_rate(output, target) * len(input)
total_samples += len(input)
error_rate = total_error / total_samples
print(f"-> {corruption_name} (Sev {severity}): Error Rate = {error_rate:.4f}")
return error_rate
def run_imagenet_c_benchmark(model, root_dir):
all_corruption_errors = {}
for c_name in CORRUPTIONS:
all_corruption_errors[c_name] = []
for severity in SEVERITIES:
err = evaluate_corruption(model, root_dir, c_name, severity)
if err is not None:
all_corruption_errors[c_name].append(err)
return all_corruption_errors
3.3 运行和计算 mCE
为了演示,我们使用一个预训练的ResNet-50模型。这里需要假设我们已经得到了干净错误率 $E_{clean}$ 和 AlexNet 的基准错误率 $E_{baseline}$ (通常取 $E_{baseline} \approx 0.448$ 错误率,即 $5.2\%$ 准确率)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29 # 假设基准值 (使用PyTorch官方ResNet-50和AlexNet在ImageNet上的结果)
E_clean_model = 0.237 # 假设ResNet-50的干净错误率
E_baseline_alexnet = 0.448 # AlexNet (用于标准化)
# 1. 加载模型
model_name = 'resnet50'
model = timm.create_model(model_name, pretrained=True).cuda()
# 2. 运行评估
print(f"Starting ImageNet-C evaluation for {model_name}...")
results = run_imagenet_c_benchmark(model, IMAGENET_C_ROOT)
# 3. 计算最终 mCE
all_ce_values = []
for c_name, errors in results.items():
if not errors: continue
# 计算特定损坏类型的平均 CE (通常是5个严重等级的平均)
# 注意:更标准的mCE计算是计算所有 75 个子集的平均值。
for E_corruption in errors:
CE = (E_corruption - E_clean_model) / (E_baseline_alexnet - E_clean_model)
all_ce_values.append(CE)
mCE = sum(all_ce_values) / len(all_ce_values)
print("\n--- Results Summary ---")
print(f"Total evaluated subsets: {len(all_ce_values)}")
print(f"Mean Corruption Error (mCE): {mCE:.4f}")
通过这个流程,AI基础设施工程师可以系统地量化部署模型的鲁棒性瓶颈,指导模型优化(例如使用数据增强或鲁棒性训练方法,如AugMix)的方向。
汤不热吧