BERT ChatBot の推論部分のコードが理解できなかったので、各ステップでのテンソルの形状を確認した。

元のコード（各ステップのテンソル形状をコメントと print で書き加えたもの）

import torch
 
from .helper import subsequent_mask
 
def evaluate(config, input_seq, sp_model, model, device, *, verbose=True):
    """Generate a chatbot reply for ``input_seq`` by greedy decoding.

    The input is tokenized, padded to ``config.max_len``, encoded once, and
    then decoded one token at a time, always taking the highest-scoring
    token, until ``config.eos_id`` is produced or ``config.max_len - 1``
    tokens have been generated.

    Args:
        config: Object with ``max_len``, ``pad_id``, ``bos_id`` and
            ``eos_id`` attributes.
        input_seq: User utterance (``str``).
        sp_model: SentencePiece-style processor providing ``EncodeAsIds``
            and ``IdToPiece``.
        model: Encoder-decoder exposing ``encode(src, src_mask)``,
            ``decode(mem, src_mask, ys, target_mask)`` and ``generate(h)``.
        device: torch device the input tensors are moved to.
        verbose: When True (the default, matching the original behaviour),
            print the per-step tensor shapes used while exploring this code.

    Returns:
        The generated response string (it is also printed).
    """
    # 1. Tokenize. Truncate so an over-long input fits the fixed-size buffer
    #    below instead of raising a shape-mismatch error.
    ids = sp_model.EncodeAsIds(input_seq)[:config.max_len]

    # 2. Pad to a fixed length: (1, max_len), e.g. [9, 30442, 11, PAD, PAD, ...].
    #    Built directly as int64 instead of a float tensor cast afterwards.
    src = torch.full((1, config.max_len), config.pad_id, dtype=torch.long)
    src[0, :len(ids)] = torch.tensor(ids, dtype=torch.long)

    # 3. Attention mask (1, max_len): True for real tokens, False at PAD.
    src_mask = src != config.pad_id
    src, src_mask = src.to(device), src_mask.to(device)

    # Sequence generated so far, seeded with BOS: (1, 1).
    ys = torch.full((1, 1), config.bos_id, dtype=torch.long, device=device)

    # The encoder pass is inside no_grad as well: this is inference only, so
    # there is no reason to record an autograd graph for it.
    with torch.no_grad():
        # 4. Encode once; mem is (1, max_len, hidden) — 768 was observed here.
        mem = model.encode(src, src_mask)

        # 5. Greedy decoding: stop at EOS or after max_len - 1 new tokens.
        for i in range(config.max_len - 1):
            # At step i, ys holds i + 1 tokens: [BOS, x1, ..., xi].
            # subsequent_mask(t) presumably yields a (1, t, t) lower-triangular
            # mask so position k attends only to positions <= k — TODO confirm
            # against the helper.
            target_mask = subsequent_mask(ys.size(1)).type_as(ys)
            if verbose:
                print(f'[{i}]: target mask: {target_mask.size()} {target_mask}')
            # out: (1, i + 1, hidden)
            out = model.decode(mem, src_mask, ys, target_mask)
            if verbose:
                print(f'out[:, -1]: {(out[:, -1]).size()}')
            # Score only the last position: (1, hidden) -> (1, vocab);
            # vocab was 32000 in the observed run.
            prob = model.generate(out[:, -1])
            # next_word: (1,) index of the highest-scoring token.
            _, next_word = torch.max(prob, dim=1)
            if next_word.item() == config.eos_id:
                break
            # Append as a (1, 1) column: ys becomes (1, i + 2).
            ys = torch.cat([ys, next_word.view(1, 1).type_as(ys)], dim=1)
            if verbose:
                print('\n')

    # Drop the leading BOS and map ids back to pieces; the '▁' word-boundary
    # marker is removed outright (same as before).
    token_ids = ys[0, 1:].tolist()
    output = ''.join(sp_model.IdToPiece(t) for t in token_ids).replace('▁', '')
    print(output)
    return output