模型剪枝(Pruning)和量化(Quantization)是AI基础设施优化的两大核心手段。它们能显著减少模型的内存占用和计算复杂度,尤其对于边缘设备和高并发推理服务至关重要。虽然早期有像Distiller这样的专用工具包,但在现代PyTorch生态中,我们通常结合使用PyTorch原生API和一些先进的优化库(如Intel Neural Compressor或NVIDIA NNCF)来实现这一目标。
本文将聚焦于如何使用PyTorch内置的torch.nn.utils.prune模块进行非结构化Magnitude Pruning,并结合PyTorch的静态后训练量化(PTQ)流程,实现模型压缩。
Contents
1. 环境准备
确保你安装了最新的PyTorch版本,因为它对量化和剪枝的支持更加完善。
1 pip install torch torchvision
2. 基于幅值(Magnitude)的非结构化剪枝
Magnitude Pruning是最简单也是最有效的剪枝方法之一。它通过移除权重绝对值最小的神经连接来实现稀疏化。PyTorch的torch.nn.utils.prune模块提供了强大的API来管理剪枝的生命周期。
2.1. 定义模型和剪枝函数
我们以一个简单的全连接网络为例,展示如何检查稀疏度并执行剪枝。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 import torch
import torch.nn as nn
from torch.nn.utils import prune
# 定义一个简单的模型
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# 辅助函数:检查模型中的稀疏度
def check_sparsity(model):
total_weights = 0
zero_weights = 0
for name, module in model.named_modules():
if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor):
total_weights += module.weight.numel()
# 注意:如果使用了prune模块,需要检查module.weight_orig
if hasattr(module, 'weight_mask'):
zero_weights += torch.sum(module.weight_mask == 0).item()
else:
# 对于未剪枝的模型,直接检查0值
zero_weights += torch.sum(module.weight == 0).item()
return zero_weights / total_weights if total_weights > 0 else 0
print(f"初始模型稀疏度: {check_sparsity(model) * 100:.2f}%")
2.2. 执行剪枝操作
我们将对fc1层执行30%的非结构化(Unstructured)剪枝。这意味着我们将移除该层中绝对值最小的30%的权重。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17 # 设定剪枝比例
pruning_amount = 0.3
# 1. 执行L1非结构化剪枝:基于权重的L1范数(绝对值)最小进行裁剪
prune.l1_unstructured(model.fc1, name='weight', amount=pruning_amount)
# 2. 检查剪枝后的稀疏度 (此时权重仍然存在,但mask已生效)
print(f"剪枝后模型稀疏度: {check_sparsity(model) * 100:.2f}%")
# 3. 永久移除剪枝参数化(将稀疏权重写入原始权重张量,并移除hook)
# 这一步是模型导出的关键,确保推理时不依赖mask
prune.remove(model.fc1, 'weight')
# 检查移除后的稀疏度,确认权重已经固化
print(f"永久移除后fc1稀疏度: {torch.sum(model.fc1.weight == 0).item() / model.fc1.weight.numel() * 100:.2f}%")
# 注意:在实际应用中,剪枝后通常需要进行微调(Fine-tuning)以恢复精度。
3. Post-Training Static Quantization (PTQ)
剪枝减少了参数数量,而量化(通常是FP32到INT8)减少了每个参数所需的比特数,同时启用特定的加速硬件指令。我们使用PyTorch的FX Graph Mode Quantization API进行后训练静态量化。
3.1. 校准与准备
PTQ需要一小批代表性数据(校准集)来计算权重的统计信息(如最小值、最大值)和激活的分布,以便确定最佳的量化尺度(Scale)和零点(Zero-Point)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34 import torch.quantization
# 假设我们已经完成了剪枝后的模型微调
model.eval()
# 准备校准数据 (使用随机数据代替实际DataLoader)
# 注意:数据的形状必须符合模型的输入要求 (这里是1x784)
calibration_data = [torch.randn(1, 784) for _ in range(20)]
# 1. 配置量化后端和模式
# 'fbgemm' 适用于服务器CPU (x86), 'qnnpack' 适用于移动端/ARM CPU
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# 2. 准备阶段:插入观察者(Observer)
# 这一步会在模型中插入模块,用于收集激活值范围的统计信息
quantized_model_prepared = torch.quantization.prepare(model, inplace=False)
# 3. 校准过程
print("开始校准...")
with torch.inference_mode():
for data in calibration_data:
quantized_model_prepared(data)
# 4. 转换阶段:将观察者转换为实际的量化模块
quantized_model = torch.quantization.convert(quantized_model_prepared, inplace=False)
print("量化完成。")
# 5. 验证效果 (检查参数类型)
print(f"原模型fc1的权重类型: {model.fc1.weight.dtype}")
# 量化模型中,如果转换成功,参数会被包装在QuantizedLinear模块中
print(f"量化模型fc1的类型: {type(quantized_model.fc1)}")
# 恭喜,你已经成功创建了一个剪枝且量化的INT8模型!
4. 总结与部署效益
通过上述步骤,我们首先使用剪枝API减小了模型文件大小和浮点运算量(FLOPs),然后通过静态量化将参数精度从FP32降至INT8。这两个步骤通常能带来2-4倍的推理加速,并极大地降低模型部署所需的内存和带宽。
在部署时,可以将这个quantized_model导出为ONNX格式(注意:导出ONNX时需要使用特定的量化导出路径,以确保量化信息被正确保留,例如使用ONNX Runtime的量化支持)或TorchScript格式,以在生产环境中运行。
汤不热吧