
Dynamic Quantization on an LSTM Word Language Model (beta)

Original: https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html

Author: James Reed

Editor: Seth Weidman

Introduction

Quantization involves converting a model's weights and activations from float to int, which can yield a smaller model size and faster inference with only a small impact on accuracy.
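
As a brief aside that is not part of the original tutorial, the minimal sketch below illustrates what this means for a single tensor: float values are mapped to 8-bit integers through a scale and zero point, and can be mapped back with only a small rounding error (the scale of 0.1 here is an arbitrary choice for illustration).

import torch

x = torch.randn(4)
# Map float32 values to int8 using a fixed scale and zero point
xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.qint8)
print(x)                # original float32 values
print(xq.int_repr())    # underlying int8 representation
print(xq.dequantize())  # approximate reconstruction of x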

In this tutorial, we apply the simplest form of quantization, dynamic quantization, to an LSTM-based next-word prediction model, closely following the word language model from the PyTorch examples.

# imports
import os
from io import open
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

1. Define the model

Here we define the LSTM model architecture, following the model from the word language model example.

class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

2. Load the text data

Next, again following the preprocessing from the word language model example, we load the Wikitext-2 dataset into a Corpus.

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)

class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            idss = []
            for line in f:
                words = line.split() + ['<eos>']
                ids = []
                for word in words:
                    ids.append(self.dictionary.word2idx[word])
                idss.append(torch.tensor(ids).type(torch.int64))
            ids = torch.cat(idss)

        return ids

model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

3. Load the pretrained model

This is a tutorial on dynamic quantization, a quantization technique that is applied after a model has been trained. Therefore, we simply load some pretrained weights into this model architecture; these weights were obtained by training for five epochs using the default settings in the word language model example.

ntokens = len(corpus.dictionary)

model = LSTMModel(
    ntoken = ntokens,
    ninp = 512,
    nhid = 256,
    nlayers = 5,
)

model.load_state_dict(
    torch.load(
        model_data_filepath + 'word_language_model_quantize.pth',
        map_location=torch.device('cpu')
        )
    )

model.eval()
print(model)

Out:

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): LSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): Linear(in_features=256, out_features=33278, bias=True)
)

Now we generate some text to make sure the pretrained model is working properly. As before, we follow the approach from the word language model example here.

input_ = torch.randint(ntokens, (1, 1), dtype=torch.long)
hidden = model.init_hidden(1)
temperature = 1.0
num_words = 1000

with open(model_data_filepath + 'out.txt', 'w') as outf:
    with torch.no_grad():  # no tracking history
        for i in range(num_words):
            output, hidden = model(input_, hidden)
            word_weights = output.squeeze().div(temperature).exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)[0]
            input_.fill_(word_idx)

            word = corpus.dictionary.idx2word[word_idx]

            outf.write(str(word.encode('utf-8')) + ('\n' if i % 20 == 19 else ' '))

            if i % 100 == 0:
                print('| Generated {}/{} words'.format(i, 1000))

with open(model_data_filepath + 'out.txt', 'r') as outf:
    all_output = outf.read()
    print(all_output)

Out:

| Generated 0/1000 words
| Generated 100/1000 words
| Generated 200/1000 words
| Generated 300/1000 words
| Generated 400/1000 words
| Generated 500/1000 words
| Generated 600/1000 words
| Generated 700/1000 words
| Generated 800/1000 words
| Generated 900/1000 words
b'broadcaster' b'good' b',' b'which' b'provided' b'for' b'a' b'vignettes' b'socially' b'and' b'the' b'FIA' b"'s" b'ad' b'.' b'The' b'state' b'into' b'this' b'position'
b'is' b'in' b'account' b'of' b'a' b'wide' b'Domonia' b'<unk>' b',' b'fallen' b'to' b'for' b'the' b'types' b'of' b'<unk>' b'developers' b'being' b'entertaining' b'.'
b'<eos>' b'The' b'Claus' b'II' b'(' b'The' b'Book' b'of' b'Karnataka' b',' b'2' b'/' b'10' b')' b'was' b'released' b'by' b'British' b'@-@' b'Irish'
b'ruler' b'arriving' b'on' b'the' b'winter' b'of' b'its' b'championship' b'orbit' b'.' b'In' b'early' b'spring' b'roles' b'dismay' b'when' b'he' b'replaced' b'by' b'a'
b'religious' b'park' b',' b'when' b'it' b'features' b'flowers' b'they' b'do' b'populist' b'.' b'temperatures' b'attempted' b'to' b'have' b'trouble' b'met' b',' b'<unk>' b','
b'and' b'karaoke' b'leads' b'to' b'some' b'return' b'up' b'as' b'or' b'seated' b'.' b'The' b'remainder' b'of' b'w' b'voltage' b'contains' b'Allah' b'in' b'the'
b'series' b'to' b'infiltrate' b'disappeared' b'.' b'Though' b'it' b'comes' b'into' b'his' b'Shinnok' b"'s" b'history' b',' b'they' b'may' b'sometimes' b'7' b'@-@' b'April'
b',' b'roughly' b'7' b'%' b'of' b'50' b'mph' b'(' b'4' b'@.@' b'8' b'in' b')' b'while' b'males' b'have' b'put' b'except' b'far' b'as'
b'alkaline' b'@-@' b'up' b'.' b'<eos>' b'Electrical' b'medical' b'rings' b'were' b'always' b'published' b'.' b'<eos>' b'Based' b'on' b'2' b'November' b',' b'Idaho' b'can'
b'be' b'estimated' b'cooking' b'and' b'<unk>' b',' b'while' b'no' b',' b'thin' b'drugs' b'was' b'poor' b'to' b'each' b'area' b'.' b'It' b'has' b'not'
b'campaigned' b'those' b'of' b'the' b'most' b'potent' b'population' b'of' b'leaves' b'in' b'all' b'condition' b',' b'because' b'they' b'were' b'forced' b'to' b'die' b'in'
b'bhandara' b'<unk>' b'that' b'culture' b'.' b'Almost' b'a' b'prose' b'plan' b',' b'there' b'have' b'been' b'only' b'clear' b',' b'it' b'occurs' b'.' b'<eos>'
b'The' b'kakapo' b'was' b'interpreted' b'on' b'1998' b'from' b'1955' b'and' b'played' b'in' b'<unk>' b',' b'Western' b'Asia' b'on' b'0' b'August' b'1966' b','
b'with' b'an' b'additional' b'population' b'that' b'Samuel' b'solemnly' b',' b'Chapman' b'sponsored' b'after' b'a' b'few' b'years' b'.' b'In' b'1990' b',' b'prominent' b'areas'
b'believe' b'that' b'as' b'being' b'an' b'rural' b'planet' b',' b'they' b'is' b'neglected' b'as' b'to' b'be' b'changed' b'.' b'Congress' b'This' b'well' b'"'
b'was' b'run' b'by' b'<unk>' b',' b'Waldemar' b'Greenwood' b'.' b'170' b'have' b'just' b'in' b'place' b',' b'he' b'overruled' b'.' b'The' b'1966' b'race'
b'is' b'a' b'embodies' b'state' b'of' b'Viking' b'or' b'most' b'generation' b',' b'not' b'in' b'the' b'codes' b'of' b'all' b'other' b'alignment' b'musical' b'politicians'
b'.' b'No' b'system' b'have' b'participated' b'on' b'3' b'to' b'9' b'%' b'of' b'any' b'urine' b',' b'with' b'both' b'drawings' b'and' b'significantly' b'towards'
b'his' b'deteriorating' b'and' b'poverty' b'.' b'As' b'a' b'rust' b',' b'contains' b'other' b'compositions' b'that' b'must' b'be' b'beneficial' b'by' b'overnight' b'or' b'fluid'
b',' b'u' b'organizations' b'can' b'seek' b'mild' b'late' b'down' b'on' b'a' b'broadside' b'and' b'leads' b'to' b'its' b'cycle' b'.' b'For' b'example' b','
b'1137' b',' b'snowmelt' b'and' b'<unk>' b'\xe2\x80\x94' b'a' b'variety' b'of' b'dealt' b';' b'Species' b'(' b'with' b'a' b'reduction' b'of' b'prohibitions' b')' b','
b'<unk>' b'exploration' b',' b'<unk>' b'an' b'fuel' b'eye' b'of' b'purple' b'trees' b',' b'was' b'shown' b'west' b'.' b'chased' b'Jack' b'of' b'claws' b','
b'his' b'vertex' b'states' b'that' b'they' b',' b'in' b'1922' b',' b'was' b'killed' b'.' b'<eos>' b'There' b'have' b'been' b'official' b'concerns' b'of' b'Boat'
b'Kerry' b'including' b'L\xc3\xaa' b'\xe3\x80\x89' b'and' b'<unk>' b'A' b'Forest' b',' b'"' b'<unk>' b',' b'because' b'<unk>' b',' b'and' b'sometimes' b'encounters' b'like' b'I'
b"'ve" b'been' b'<unk>' b'.' b'"' b'<unk>' b'Hunter' b'pathway' b'writes' b'it' b'entering' b'the' b'second' b'.' b'The' b'kakapo' b'is' b'gems' b'used' b'after'
b'died' b'from' b'two' b'games' b'in' b'six' b'<unk>' b',' b'her' b'feature' b'and' b'called' b'"' b'mercenaries' b'"' b',' b'which' b'supported' b'by' b'the'
b'Selective' b'Race' b'.' b'"' b'<eos>' b'Bono' b'Dutch' b'struggles' b'to' b'the' b'species' b'<unk>' b',' b'especially' b'crusaders' b'I' b'lives' b'process' b',' b'but'
b'Constantin' b'approximate' b'and' b'character' b'or' b'so' b'.' b'There' b'have' b'numerous' b'pale' b'dioceses' b'as' b'a' b'resistant' b';' b'the' b'Inn' b'Comic' b'@-@'
b'white' b'individuals' b',' b'its' b'flat' b',' b'<unk>' b'and' b'correct' b',' b'in' b'which' b'they' b'felt' b'.' b'In' b'the' b'arms' b',' b'the'
b'original' b'occasion' b'about' b'Spanish' b'sites' b'all' b'(' b'millionaire' b'lay' b';' b'or' b'160' b'@-@' b'mosquitoes' b')' b'v' b'<unk>' b'(' b'c' b')'
b'.' b'The' b'bird' b'is' b'extremely' b'paved' b',' b'and' b'they' b'are' b'claimed' b'to' b'wedding' b'the' b'<unk>' b'of' b'Excellence' b',' b'and' b'an'
b'extinct' b'composite' b',' b'cute' b'outside' b'<unk>' b'.' b'This' b'may' b'be' b'seen' b'by' b'the' b'Seer' b'that' b'Tempest' b'"' b'comes' b'"' b'over'
b'a' b'bright' b'judicial' b'guitar' b',' b'which' b'describes' b',' b'and' b'tend' b'to' b'be' b'seen' b'.' b'<eos>' b'<eos>' b'=' b'=' b'Conservation' b'for'
b'contraception' b'=' b'=' b'<eos>' b'<eos>' b'Grieco' b'Island' b'is' b'a' b'eventually' b'scale' b'word' b'to' b'a' b'tropical' b'storm' b',' b'based' b'in' b'a'
b'pre' b'\xe2\x80\x93' b'9' b'lead' b',' b'a' b'forces' b'after' b'a' b'additional' b',' b'grey' b'substance' b',' b'Metro' b',' b'background' b',' b'and' b'cooperate'
b'with' b'its' b'overly' b'overview' b',' b'so' b'the' b'heaviest' b'route' b',' b'and' b'\xc2\xb3' b'.' b'portion' b'may' b'occur' b'this' b'other' b'up' b'an'
b'<unk>' b'break' b',' b'then' b'or' b'deep' b'distinct' b'or' b'female' b'offspring' b',' b'but' b'even' b'understand' b'.' b'Following' b'God' b'(' b'no' b'nervous'
b'image' b'from' b'complaints' b')' b',' b'the' b'player' b'represents' b'three' b'or' b'over' b'9' b'\xc2\xb0' b'large' b'(' b'five' b'weeks' b'of' b'many' b'cats'
b')' b',' b'as' b'it' b'targets' b'for' b'the' b'second' b'female' b'together' b'.' b'159' b',' b'it' b'also' b'spend' b'bold' b'markets' b'and' b'its'
b'players' b'powers' b',' b'dubbed' b'those' b'of' b'lengths' b'.' b'Most' b'are' b'arrow' b'could' b'be' b'noticed' b'involving' b'they' b'fall' b'.' b'On' b'FAU'
b"'s" b'only' b'lifetime' b'she' b'treated' b'or' b'their' b'apparent' b'soaring' b'proposition' b'has' b'5th' b'of' b'those' b'eye' b',' b'but' b'knows' b'in' b'a'
b'<unk>' b'Network' b';' b'which' b'of' b'that' b'reality' b'or' b'artificial' b'when' b'struggling' b'Bungie' b'is' b'successful' b'.' b'The' b'<unk>' b'sound' b'of' b'frontier'
b'ahead' b'for' b'damage' b'came' b'on' b',' b'so' b'the' b'first' b'series' b'funded' b'by' b'its' b'bowls' b',' b'a' b'chant' b'.' b'They' b'may'
b'be' b'used' b'Pongsak' b'or' b'occasionally' b'protected' b'them' b'.' b'Fingal' b'cylindrical' b'conspired' b'on' b'a' b'variety' b'of' b'prey' b',' b'<unk>' b',' b'Zach'
b',' b'and' b'young' b'possessing' b'Westland' b'valleys' b'.' b'Otherwise' b',' b'I' b'do' b'at' b'them' b'in' b'first' b'@-@' b'season' b'woodland' b',' b'where'
b'they' b'weighed' b'them' b'to' b'correct' b'a' b'list' b'of' b'other' b'birds' b'.' b'Another' b'theme' b'where' b'or' b',' b'it' b'is' b'a' b'appropriate'
b'source' b',' b'this' b'competed' b'in' b'integral' b'Waiouru' b'alone' b',' b'the' b'pathways' b'under' b'Aravind' b',' b'and' b'others' b',' b'instead' b'of' b'westward'
b',' b'as' b'they' b'are' b'quarters' b'and' b'caused' b'in' b'males' b'.' b'Once' b'selective' b'centered' b',' b'they' b'threats' b'were' b'Zuniceratops' b'.' b'Although'
b'the' b'most' b'spots' b'replication' b'became' b'a' b'fragile' b'pointer' b'(' b'a' b'pair' b'of' b'<unk>' b')' b',' b'strongly' b'"' b'mammals' b'"' b','
b'which' b'give' b'Powderfinger' b'to' b'persecution' b'.' b'Other' b'conifers' b'but' b'even' b'only' b'swallow' b'so' b'every' b'symbols' b'of' b'Manders' b',' b'in' b'massive'

It's no GPT-2, but it looks like the model has started to learn the structure of language!

We're almost ready to demonstrate dynamic quantization. We just need to define a few helper functions:

bptt = 25
criterion = nn.CrossEntropyLoss()
eval_batch_size = 1

# create test data set
def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    return data.view(bsz, -1).t().contiguous()

test_data = batchify(corpus.test, eval_batch_size)

# Evaluation functions
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

def evaluate(model_, data_source):
    # Turn on evaluation mode which disables dropout.
    model_.eval()
    total_loss = 0.
    hidden = model_.init_hidden(eval_batch_size)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output, hidden = model_(data, hidden)
            hidden = repackage_hidden(hidden)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

4. Test dynamic quantization

Finally, we can call torch.quantization.quantize_dynamic on the model! Specifically,

  • We specify that we want the nn.LSTM and nn.Linear modules in the model to be quantized
  • We specify that we want the weights to be converted to int8
import torch.quantization

quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
print(quantized_model)

Out:

LSTMModel(
  (drop): Dropout(p=0.5, inplace=False)
  (encoder): Embedding(33278, 512)
  (rnn): DynamicQuantizedLSTM(512, 256, num_layers=5, dropout=0.5)
  (decoder): DynamicQuantizedLinear(in_features=256, out_features=33278, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

The model looks the same; how does this benefit us? First, we see a significant reduction in model size:

def print_size_of_model(model):
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

print_size_of_model(model)
print_size_of_model(quantized_model)

Out:

Size (MB): 113.945726
Size (MB): 79.739984
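
To see where these savings come from, a rough sketch like the following can help (this is an addition to the tutorial; it reuses the save-to-disk trick above on each top-level submodule as a proxy for its share of the total size). Only the nn.LSTM and nn.Linear modules are dynamically quantized; the nn.Embedding encoder stays in float32 and accounts for most of the remaining size.

def print_size_per_module(model):
    # Save each top-level submodule's state_dict separately as a rough
    # proxy for how much of the total model size it accounts for.
    for name, module in model.named_children():
        torch.save(module.state_dict(), "temp.p")
        print(name, 'Size (MB):', os.path.getsize("temp.p")/1e6)
        os.remove('temp.p')

print_size_per_module(quantized_model)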

Second, we see faster inference time, with no difference in evaluation loss:

Note: since quantized models run single-threaded, we set the number of threads to one for a single-threaded comparison.

torch.set_num_threads(1)

def time_model_evaluation(model, test_data):
    s = time.time()
    loss = evaluate(model, test_data)
    elapsed = time.time() - s
    print('''loss: {0:.3f}\nelapsed time (seconds): {1:.1f}'''.format(loss, elapsed))

time_model_evaluation(model, test_data)
time_model_evaluation(quantized_model, test_data)

Out:

loss: 5.167
elapsed time (seconds): 251.3
loss: 5.168
elapsed time (seconds): 166.3

Running this locally on a MacBook Pro, inference takes about 200 seconds without quantization and only about 100 seconds with quantization.

Conclusion

Dynamic quantization can be an easy way to reduce model size while having only a limited impact on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue here if you have any.

Total running time of the script: (7 minutes 3.126 seconds)

Download Python source code: dynamic_quantization_tutorial.py

Download Jupyter notebook: dynamic_quantization_tutorial.ipynb
