设为首页收藏本站
天天打卡

 找回密码
 立即注册
搜索
查看: 151|回复: 14

PyTorch实现模型剪枝的方法

[复制链接]
  • 打卡等级:无名新人
  • 打卡总天数:1
  • 打卡月天数:0
  • 打卡总奖励:13
  • 最近打卡:2024-05-10 17:32:55

4

主题

53

回帖

208

积分

中级会员

积分
208
发表于 2024-4-20 09:42:00 | 显示全部楼层 |阅读模式
目录


指南概述

在这篇文章中,我将向你介绍如何在PyTorch中实现模型剪枝。剪枝是一种优化模型的技术,可以帮助减少模型的大小和计算量,同时保持模型的准确性。我将为你提供一个详细的步骤指南,并指导你如何在每个步骤中使用适当的PyTorch代码。

整体流程

下面是实现PyTorch剪枝的整体流程,我们将按照这些步骤逐步进行操作:
步骤操作1.加载预训练模型2.定义剪枝算法3.执行剪枝操作4.重新训练和微调模型5.评估剪枝后的模型性能
步骤详解


步骤1:加载预训练模型

首先,我们需要加载一个预训练的模型作为我们的基础模型。在这里,我们以ResNet18为例。
  1. import torch
  2. import torchvision.models as models

  3. # 加载预训练的ResNet18模型
  4. model = models.resnet18(pretrained=True)
复制代码
步骤2:定义剪枝算法

接下来,我们需要定义一个剪枝算法,这里我们以Global Magnitude Pruning(全局幅度剪枝)为例。
  1. from torch.nn.utils.prune import global_unstructured

  2. # 定义剪枝比例
  3. pruning_rate = 0.5

  4. # 对模型的全连接层进行剪枝
  5. def prune_model(model, pruning_rate):
  6.     for name, module in model.named_modules():
  7.         if isinstance(module, torch.nn.Linear):
  8.             global_unstructured(module, pruning_dim=0, amount=pruning_rate)
复制代码
步骤3:执行剪枝操作

现在,我们可以执行剪枝操作,并查看剪枝后的模型结构。
  1. prune_model(model, pruning_rate)

  2. # 查看剪枝后的模型结构
  3. print(model)
复制代码
步骤4:重新训练和微调模型

剪枝后的模型需要重新进行训练和微调,以保证模型的准确性和性能。
  1. # 定义损失函数和优化器
  2. criterion = torch.nn.CrossEntropyLoss()
  3. optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

  4. # 重新训练和微调模型
  5. # 省略训练代码
复制代码
步骤5:评估剪枝后的模型性能

最后,我们需要对剪枝后的模型进行评估,以比较剪枝前后的性能差异。
  1. # 评估剪枝后的模型
  2. # 省略评估代码
复制代码
补:PyTorch中实现的剪枝方式有三种:

  • 局部剪枝
  • 全局剪枝
  • 自定义剪枝
局部剪枝
局部剪枝实验,假定对模型的第一个卷积层中的权重进行剪枝
  1. model_1 = LeNet()
  2. module = model_1.conv1
  3. # 剪枝前
  4. print(list(module.named_parameters()))
  5. print(list(module.named_buffers()))
  6. prune.random_unstructured(module, name="weight", amount=0.3)
  7. # 剪枝后
  8. print(list(module.named_parameters()))
  9. print(list(module.named_buffers()))
复制代码
运行结果
  1. ## 剪枝前[('weight', Parameter containing:tensor([[[[ 0.1729, -0.0109, -0.1399],          [ 0.1019,  0.1883,  0.0054],          [-0.0790, -0.1790, -0.0792]]],                ...
  2.         [[[ 0.2465,  0.2114,  0.3208],          [-0.2067, -0.2097, -0.0431],          [ 0.3005, -0.2022,  0.1341]]]], requires_grad=True)), ('bias', Parameter containing:tensor([-0.1437,  0.0605,  0.1427, -0.3111, -0.2476,  0.1901],       requires_grad=True))][]
  3. ## 剪枝后[('bias', Parameter containing:tensor([-0.1437,  0.0605,  0.1427, -0.3111, -0.2476,  0.1901],       requires_grad=True)), ('weight_orig', Parameter containing:tensor([[[[ 0.1729, -0.0109, -0.1399],          [ 0.1019,  0.1883,  0.0054],          [-0.0790, -0.1790, -0.0792]]],
  4.         ...
  5.         [[[ 0.2465,  0.2114,  0.3208],          [-0.2067, -0.2097, -0.0431],          [ 0.3005, -0.2022,  0.1341]]]], requires_grad=True))]
  6. [('weight_mask', tensor([[[[1., 1., 1.],          [1., 1., 1.],          [1., 1., 1.]]],
  7.         [[[0., 1., 0.],          [0., 1., 1.],          [1., 0., 1.]]],
  8.         [[[0., 1., 1.],          [1., 0., 1.],          [1., 0., 1.]]],
  9.         [[[1., 1., 1.],          [1., 0., 1.],          [0., 1., 0.]]],
  10.         [[[0., 0., 1.],          [0., 1., 1.],          [1., 1., 1.]]],
  11.         [[[0., 1., 1.],          [0., 1., 0.],          [1., 1., 1.]]]]))]
复制代码
模型经历剪枝操作后, 原始的权重矩阵weight参数不见了,变成了weight_orig。 并且剪枝前打印为空列表的
  1. module.named_buffers()
复制代码
,此时拥有了一个weight_mask参数。经过剪枝操作后的模型,原始的参数存放在了weight_orig中,对应的剪枝矩阵存放在weight_mask中, 而将weight_mask视作掩码张量,再和weight_orig相乘的结果就存放在了weight中。
全局剪枝
局部剪枝只能以部分网络模块为单位进行剪枝,更广泛的剪枝策略是采用全局剪枝(global pruning),比如在整体网络的视角下剪枝掉20%的权重参数,而不是在每一层上都剪枝掉20%的权重参数。采用全局剪枝后,不同的层被剪掉的百分比不同。
  1. model_2 = LeNet().to(device=device)

  2. # 首先打印初始化模型的状态字典
  3. print(model_2.state_dict().keys())

  4. # 构建参数集合, 决定哪些层, 哪些参数集合参与剪枝
  5. parameters_to_prune = (
  6.             (model_2.conv1, 'weight'),
  7.             (model_2.conv2, 'weight'),
  8.             (model_2.fc1, 'weight'),
  9.             (model_2.fc2, 'weight'),
  10.             (model_2.fc3, 'weight'))
  11. # 调用prune中的全局剪枝函数global_unstructured执行剪枝操作, 此处针对整体模型中的20%参数量进行剪枝
  12. prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)

  13. # 最后打印剪枝后的模型的状态字典
  14. print(model_2.state_dict().keys())
复制代码
输出结果
  1. odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'fc1.bias', 'fc1.weight_orig', 'fc1.weight_mask', 'fc2.bias', 'fc2.weight_orig', 'fc2.weight_mask', 'fc3.bias', 'fc3.weight_orig', 'fc3.weight_mask'])
复制代码
当采用全局剪枝策略的时候(假定20%比例参数参与剪枝),仅保证模型总体参数量的20%被剪枝掉,具体到每一层的情况则由模型的具体参数分布情况来定。
自定义剪枝
自定义剪枝可以自定义一个子类,用来实现具体的剪枝逻辑,比如对权重矩阵进行间隔性的剪枝
  1. class my_pruning_method(prune.BasePruningMethod):
  2.     PRUNING_TYPE = "unstructured"
  3.    
  4.     def compute_mask(self, t, default_mask):
  5.         mask = default_mask.clone()
  6.         mask.view(-1)[::2] = 0
  7.         return mask
  8.    
  9. def my_unstructured_pruning(module, name):
  10.     my_pruning_method.apply(module, name)
  11.     return module

  12. model_3 = LeNet()
  13. print(model_3)
复制代码
在剪枝前查看网络结构
  1. LeNet(
  2.   (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  3.   (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  4.   (fc1): Linear(in_features=400, out_features=120, bias=True)
  5.   (fc2): Linear(in_features=120, out_features=84, bias=True)
  6.   (fc3): Linear(in_features=84, out_features=10, bias=True)
  7. )
复制代码
采用自定义剪枝的方式对局部模块fc3进行剪枝
  1. my_unstructured_pruning(model.fc3, name="bias")
  2. print(model.fc3.bias_mask)
复制代码
输出结果
  1. tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
复制代码
最后的剪枝效果与实现的逻辑一致。

总结

通过上面的步骤指南和代码示例,相信你可以学会如何在PyTorch中实现模型剪枝。剪枝是一个有效的模型优化技术,可以帮助你构建更加高效和精确的深度学习模型。
到此这篇关于PyTorch实现模型剪枝的方法的文章就介绍到这了,更多相关PyTorch 剪枝内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
  • 打卡等级:无名新人
  • 打卡总天数:1
  • 打卡月天数:0
  • 打卡总奖励:30
  • 最近打卡:2024-04-26 21:46:17

2

主题

67

回帖

209

积分

中级会员

积分
209
发表于 2024-4-21 21:21:07 | 显示全部楼层
这个话题很有趣,我想多了解一些

0

主题

40

回帖

80

积分

注册会员

积分
80
发表于 2024-5-4 16:57:39 | 显示全部楼层
我完全同意你的观点

0

主题

47

回帖

95

积分

注册会员

积分
95
发表于 2024-5-10 03:00:50 | 显示全部楼层
这个话题很有趣,我想多了解一些

0

主题

40

回帖

80

积分

注册会员

积分
80
发表于 2024-5-17 13:03:15 | 显示全部楼层
能给个链接吗?我想深入了解一下。

0

主题

48

回帖

95

积分

注册会员

积分
95
发表于 2024-5-23 21:11:13 | 显示全部楼层
太棒了!感谢分享这个信息!

1

主题

61

回帖

145

积分

注册会员

积分
145
发表于 2024-6-2 13:55:17 | 显示全部楼层
我完全同意你的观点

1

主题

58

回帖

140

积分

注册会员

积分
140
发表于 2024-6-11 00:02:38 | 显示全部楼层
666666666666666666

1

主题

57

回帖

134

积分

注册会员

积分
134
发表于 2024-8-17 19:37:48 | 显示全部楼层
我完全同意你的观点

1

主题

46

回帖

116

积分

注册会员

积分
116
发表于 2024-8-23 12:12:29 | 显示全部楼层
友善的讨论氛围是非常重要的。
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

手机版|小黑屋|爱云论坛 - d.taiji888.cn - 技术学习 免费资源分享 ( 蜀ICP备2022010826号 )|天天打卡

GMT+8, 2024-11-15 08:55 , Processed in 0.082726 second(s), 26 queries .

Powered by i云网络 Licensed

© 2023-2028 正版授权

快速回复 返回顶部 返回列表