Large gap between predictions before and after pruning in the demo? #8

@xn1997

Description

In prune_by_class.py I compared the model's predictions before and after pruning, and the gap is very large. Is this normal, or am I misunderstanding something? Code below:

import sys

sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np

# The code below demonstrates removing, from every BN layer, the channels
# whose weight coefficients have the smallest 20% absolute values
inputs_sample = torch.ones(1, 3, 224, 224).to('cpu')
# Load the model
model = torchvision.models.vgg11_bn()
result_source = model(inputs_sample)
# Create an ONNXGraph object bound to the model to be pruned
graph = torchpruner.ONNXGraph(model)
# Build the static ONNX graph structure; the input tensors must be specified
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# Iterate over all modules
for key in list(graph.modules):
    module = graph.modules[key]
    # If this module corresponds to a BN layer
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        # Get the underlying nn object
        nn_object = module.nn_object
        # Sort and take the indices of the smallest 20% of weights by absolute value
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)
        if context:
            # graph holds numpy.ndarray copies of each layer's parameters and
            # output tensors, so it must be updated after the cut
            graph = torchpruner.ONNXGraph(model)  # recreating the graph is optional
            graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# The new model is the pruned model
print(model)

result_prune = model(inputs_sample)
print(f"Result before pruning: {result_source.sum()}")
print(f"Result after pruning: {result_prune.sum()}")
print(f"Difference: {(abs(result_source - result_prune)).sum()}")
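As a side note, the argsort-based channel selection used above can be checked in isolation. This is a minimal sketch with made-up weights (the `weight` vector is hypothetical, not from the actual model); only NumPy is needed:

```python
import numpy as np

# Hypothetical BN weight vector (10 channels), for illustration only
weight = np.array([0.5, -0.01, 0.3, 0.02, -0.8, 0.1, -0.05, 0.9, 0.4, -0.2])

k = int(weight.shape[0] * 0.2)          # 20% of 10 channels -> 2
index = np.argsort(np.abs(weight))[:k]  # indices of the 2 smallest |w|
print(sorted(index.tolist()))           # -> [1, 3]
```

Channels 1 and 3 hold the smallest absolute weights (0.01 and 0.02), so those are the indices passed to `cut_analysis`.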
