import torch
from .helper import subsequent_mask
def evaluate(config, input_seq, sp_model, model, device, verbose=False):
    """Greedily generate a response for ``input_seq``, print it and return it.

    Pipeline: tokenize -> pad to ``config.max_len`` -> encode once -> decode
    one token at a time (greedy argmax) until ``config.eos_id`` is produced or
    ``config.max_len - 1`` tokens have been generated -> detokenize.

    Args:
        config: Object exposing ``max_len``, ``pad_id``, ``bos_id`` and
            ``eos_id``.
        input_seq: Raw input text.
        sp_model: SentencePiece-style tokenizer providing ``EncodeAsIds`` and
            ``IdToPiece``.
        model: Encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(mem, src_mask, ys, target_mask)`` and ``generate(h)``.
        device: Torch device the tensors are created on.
        verbose: If True, print per-step debug info (target-mask and
            decoder-output shapes). Defaults to False.

    Returns:
        The generated response string (it is also printed to stdout).
    """
    # NOTE(review): this function does not call model.eval(); if the model has
    # dropout, the caller presumably must switch it to eval mode first — confirm.

    # 1. Tokenize. Truncate to max_len so the prefix assignment below cannot
    # fail with a shape mismatch on long inputs.
    ids = list(sp_model.EncodeAsIds(input_seq))[:config.max_len]

    # 2. Pad: a (1, max_len) long tensor filled with pad_id whose prefix is then
    # overwritten with the token ids, e.g. [9, 30442, 11, PAD, PAD, ...].
    src = torch.full(
        (1, config.max_len), config.pad_id, dtype=torch.long, device=device
    )
    src[0, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=device)

    # 3. Source mask, shape (1, max_len): True at real tokens, False at padding.
    src_mask = src != config.pad_id

    with torch.no_grad():
        # 4. Encode the source once; the memory is reused at every decode step.
        mem = model.encode(src, src_mask)

        # 5. Greedy decoding, starting from [BOS]; ys has shape (1, step + 1).
        ys = torch.full((1, 1), config.bos_id, dtype=torch.long, device=device)
        for i in range(config.max_len - 1):
            # Causal (lower-triangular) mask from the project helper, so each
            # position only attends to itself and earlier positions.
            target_mask = subsequent_mask(ys.size(1)).type_as(ys)
            if verbose:
                print(f'[{i}]: target mask: {target_mask.size()} {target_mask}')

            out = model.decode(mem, src_mask, ys, target_mask)
            if verbose:
                print(f'out[:, -1]: {(out[:, -1]).size()}')

            # Score the vocabulary from the last position only and take argmax.
            prob = model.generate(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_token_id = next_word[0].item()
            if next_token_id == config.eos_id:
                break

            next_token = torch.full(
                (1, 1), next_token_id, dtype=torch.long, device=device
            )
            ys = torch.cat([ys, next_token], dim=1)

    if verbose:
        print('\n')

    # 6. Detokenize, dropping the leading BOS. SentencePiece marks word starts
    # with '▁' (U+2581): map it back to a space instead of deleting it, which
    # would glue all words together, then drop the leading space.
    generated_ids = ys[0, 1:].tolist()
    output = ''.join(sp_model.IdToPiece(t) for t in generated_ids)
    output = output.replace('▁', ' ').strip()

    # Generated response.
    print(output)
    return output