Skip to content

修剪教程

原文: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

注意

单击此处的下载完整的示例代码

作者Michela Paganini

最新的深度学习技术依赖于难以部署的过度参数化模型。 相反,已知生物神经网络使用有效的稀疏连通性。 为了减少内存,电池和硬件消耗,同时又不牺牲精度,在设备上部署轻量级模型并通过私有设备内计算来确保私密性,确定通过减少模型中参数数量来压缩模型的最佳技术很重要。 在研究方面,修剪用于研究参数过度配置和参数不足网络之间学习动态的差异,以研究幸运稀疏子网络和初始化(“ 彩票”)作为破坏性对象的作用。 神经结构搜索技术等等。

在本教程中,您将学习如何使用torch.nn.utils.prune稀疏神经网络,以及如何扩展它以实现自己的自定义修剪技术。

要求

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

建立模型

在本教程中,我们使用 LeCun 等人,1998 年的 LeNet 体系结构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检查模块

让我们检查一下 LeNet 模型中的(未经修剪的)conv1层。 目前它将包含两个参数weightbias,并且没有缓冲区。

module = model.conv1
print(list(module.named_parameters()))

出:

[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

Out:

[]

修剪模块

要修剪模块(在此示例中,为 LeNet 架构的conv1层),请首先从torch.nn.utils.prune中可用的那些技术中选择一种修剪技术(或通过子类化BasePruningMethod实现您自己的)。 然后,指定模块和该模块中要修剪的参数的名称。 最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数。

在此示例中,我们将在conv1层中名为weight的参数中随机修剪 30%的连接。 模块作为第一个参数传递给函数; name使用其字符串标识符在该模块内标识参数; amount表示与修剪的连接百分比(如果它是介于 0 和 1 之间的浮点数),或者表示与修剪的连接的绝对数量(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)

修剪是通过从参数中删除weight并将其替换为名为weight_orig的新参数(即,将"_orig"附加到初始参数name)来进行的。 weight_orig存储未修剪的张量版本。 bias未修剪,因此它将保持完整。

print(list(module.named_parameters()))

Out:

[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True))]

通过以上选择的修剪技术生成的修剪掩码将保存为名为weight_mask的模块缓冲区(即,将"_mask"附加到初始参数name)。

print(list(module.named_buffers()))

Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

为了使前向通过不更改即可工作,需要存在weight属性。 torch.nn.utils.prune中实现的修剪技术计算权重的修剪版本(通过将掩码与原始参数组合)并将其存储在属性weight中。 注意,这不再是module的参数,现在只是一个属性。

print(module.weight)

Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.0000, -0.2106],
          [ 0.1776, -0.1845, -0.0000],
          [-0.0708,  0.0000,  0.3095]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.0000, -0.0000],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.0000],
          [ 0.2159, -0.1725,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最后,使用 PyTorch 的forward_pre_hooks在每次向前传递之前应用修剪。 具体来说,当修剪module时(如我们在此处所做的那样),它将为与之关联的每个参数获取forward_pre_hook进行修剪。 在这种情况下,由于到目前为止我们只修剪了名称为weight的原始参数,因此只会出现一个钩子。

print(module._forward_pre_hooks)

Out:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>)])

为了完整起见,我们现在也可以修剪bias,以查看module的参数,缓冲区,挂钩和属性如何变化。 仅出于尝试另一种修剪技术的目的,在此我们按 L1 范数修剪偏差中的 3 个最小条目,如l1_unstructured修剪功能中所实现的。

prune.l1_unstructured(module, name="bias", amount=3)

现在,我们希望命名的参数同时包含weight_orig(从前)和bias_orig。 缓冲区将包括weight_maskbias_mask。 两个张量的修剪版本将作为模块属性存在,并且该模块现在将具有两个forward_pre_hooks

print(list(module.named_parameters()))

Out:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

print(module.bias)

Out:

tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000,  0.1425], device='cuda:0',
       grad_fn=<MulBackward0>)

print(module._forward_pre_hooks)

Out:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f1e6c425550>)])

迭代修剪

一个模块中的同一参数可以被多次修剪,各种修剪调用的效果等于串联应用的各种蒙版的组合。 PruningContainercompute_mask方法可处理新遮罩与旧遮罩的组合。

例如,假设我们现在要进一步修剪module.weight,这一次是使用沿着张量的第 0 轴的结构化修剪(第 0 轴对应于卷积层的输出通道,并且conv1的维数为 6) ,基于渠道的 L2 规范。 这可以通过ln_structuredn=2dim=0功能来实现。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

现在,对应的钩子将为torch.nn.utils.prune.PruningContainer类型,并将存储应用于weight参数的修剪历史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

Out:

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f1e6c425400>, <torch.nn.utils.prune.LnStructured object at 0x7f1e6c4259b0>]

序列化修剪的模型

所有相关的张量,包括掩码缓冲区和用于计算修剪的张量的原始参数,都存储在模型的state_dict中,因此可以根据需要轻松地序列化和保存。

print(model.state_dict().keys())

Out:

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

删除修剪重新参数化

要使修剪永久化,请删除weight_origweight_mask的重新参数化,然后删除forward_pre_hook,我们可以使用torch.nn.utils.pruneremove功能。 请注意,这不会撤消修剪,好像从未发生过。 它只是通过将参数weight重新分配为模型参数(修剪后的版本)来使其永久不变。

删除重新参数化之前:

print(list(module.named_parameters()))

Out:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],

        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],

        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],

        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],

        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],

        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

print(module.weight)

Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

删除重新参数化后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))

Out:

[('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],

        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],

        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],

        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],

        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]

print(list(module.named_buffers()))

Out:

[('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

修剪模型中的多个参数

通过指定所需的修剪技术和参数,我们可以轻松地修剪网络中的多个张量,也许根据它们的类型,如在本示例中将看到的那样。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

Out:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全球修剪

到目前为止,我们仅研究了通常被称为“局部”修剪的方法,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐一修剪模型中的张量的做法。 到该张量中的其他条目。 但是,一种常见且可能更强大的技术是通过删除(例如)删除整个模型中最低的 20%的连接,而不是删除每一层中最低的 20%的连接来一次修剪模型。 这很可能导致每个层的修剪百分比不同。 让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

现在,我们可以检查在每个修剪参数中引起的稀疏性,该稀疏性将不等于每层中的 20%。 但是,全球稀疏度将(大约)为 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100\. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

Out:

Sparsity in conv1.weight: 7.41%
Sparsity in conv2.weight: 9.49%
Sparsity in fc1.weight: 22.00%
Sparsity in fc2.weight: 12.28%
Sparsity in fc3.weight: 9.76%
Global sparsity: 20.00%

使用自定义修剪功能扩展torch.nn.utils.prune

要实现自己的修剪功能,您可以通过继承BasePruningMethod基类来扩展nn.utils.prune模块,这与所有其他修剪方法一样。 基类为您实现以下方法:__call__apply_maskapplypruneremove。 除了某些特殊情况外,您不必为新的修剪技术重新实现这些方法。 但是,您将必须实现__init__(构造函数)和compute_mask(有关如何根据修剪技术的逻辑为给定张量计算掩码的说明)。 另外,您将必须指定此技术实现的修剪类型(支持的选项为globalstructuredunstructured)。 需要确定在迭代应用修剪的情况下如何组合蒙版。 换句话说,当修剪预修剪的参数时,当前的修剪技术应作用于参数的未修剪部分。 指定PRUNING_TYPE将使PruningContainer(处理修剪蒙版的迭代应用)正确识别要修剪的参数。

例如,假设您要实施一种修剪技术,以修剪张量中的所有其他条目(或者-如果先前已修剪过张量,则在张量的其余未修剪部分中)。 这将是PRUNING_TYPE='unstructured',因为它作用于层中的单个连接,而不作用于整个单元/通道('structured'),或作用于不同的参数('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

现在,要将其应用于nn.Module中的参数,还应该提供一个简单的函数来实例化该方法并将其应用。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

试试吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

Out:

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

脚本的总运行时间:(0 分钟 0.146 秒)

Download Python source code: pruning_tutorial.py Download Jupyter notebook: pruning_tutorial.ipynb

由狮身人面像画廊生成的画廊



回到顶部