Skip to content

resnet50剪枝报错 #24

@Wq-dd

Description

@Wq-dd

你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:

  1. 首先将训练好的模型计算bn的阈值得到每个bn层应该要剪枝的索引,并保存到一个dict里。
  2. 然后循环1中的dict使用torchprunner去剪枝,会遇到前面的某些层如果剪了过多通道,后面层再剪时会出现索引越界。
  3. 下面是我的部分代码。
        import torchpruner 
        # 创建ONNXGraph对象,绑定需要被剪枝的模型
        self.model.eval()
        graph = torchpruner.ONNXGraph(self.model.cpu())
        ##build ONNX静态图结构,需要指定输入的张量
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),))
        for i, (k, v) in enumerate(mask_dict_for_pruner.items()):
        # 获取conv1模块对应的module
            conv1_module = graph.modules[k]

            # 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道
            # weight权重out_channels对应的通道维度为0
            result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0)

            # 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文
            self.model, context = torchpruner.set_cut(self.model, result)
        # 新的model即为剪枝后的模型
        print(self.model)```

请问是我的用法不对吗还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢

Metadata

Metadata

Assignees

No one assigned

    Labels

    questionFurther information is requested

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions