-
Notifications
You must be signed in to change notification settings - Fork 41
Open
Description
我在prune_by_class.py
程序中测试了一下剪枝前后模型预测结果的差别,发现结果差距很大,这种现象正常吗,还是说我的理解有问题。代码如下
import sys
sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np
# 以下代码示例了对每一个BN层去除其weight系数绝对值前20%小的层
inputs_sample = torch.ones(1, 3, 224, 224).to('cpu')
# 加载模型
model = torchvision.models.vgg11_bn()
result_source = model(inputs_sample)
# 创建ONNXGraph对象,绑定需要被剪枝的模型
graph = torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))
# 遍历所有的Module
for key in list(graph.modules):
module = graph.modules[key]
# 如果该module对应了BN层
if isinstance(module.nn_object, torch.nn.BatchNorm2d):
# 获取该对象
nn_object = module.nn_object
# 排序,取前20%小的权重值对应的index
weight = nn_object.weight.detach().cpu().numpy()
index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.02)]
result = module.cut_analysis("weight", index=index, dim=0)
model, context = torchpruner.set_cut(model, result)
if context:
# graph 存放了各层参数和输出张量的 numpy.ndarray 版本,需要更新
graph = torchpruner.ONNXGraph(model) # 也可以不重新创建 graph
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))
# 新的model即为剪枝后的模型
print(model)
result_prune = model(inputs_sample)
print(f"剪枝前结果:{result_source.sum()}")
print(f"剪枝后结果:{result_prune.sum()}")
print(f"数据差距{(abs(result_source-result_prune)).sum()}")
Metadata
Metadata
Assignees
Labels
No labels