
#트랜스포머 모델 구성
import math
import torch
from torch import nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len, dropout=0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[: x.size(0)]
return self.dropout(x)
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size, emb_size):
super().__init__()
self.embedding = nn.Embedding(vocab_size, emb_size)
self.emb_size = emb_size
def forward(self, tokens):
return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
class Seq2SeqTransformer(nn.Module):
def __init__(
self,
num_encoder_layers,
num_decoder_layers,
emb_size,
max_len,
nhead,
src_vocab_size,
tgt_vocab_size,
dim_feedforward,
dropout=0.1,
):
#TokenEmbedding 클래스로 소스 데이터와 입력 데이터를 입력 임베딩으로 변환하여 src_tok_emb와 tgt_tok_emb를 생성
#소스와 타깃 데이터의 어휘 사전 크기를 입력받아 트랜스포머 임베딩 크기로 변환한다.
super().__init__()
self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
#PositionalEncoding을 적용해 트랜스포머 블록에 입력
self.positional_encoding = PositionalEncoding(
d_model=emb_size, max_len=max_len, dropout=dropout
)
#트랜스포머 블록: 파이토치 제공 트랜스포머 클래스 적용
self.transformer = nn.Transformer(
d_model=emb_size,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
#순방향 메서드 마지막에 적용되는 generator
#마지막 트랜스포머 디코더 블록에서 산출되는 벡터를 선형 변환해 어휘 사전에 대한 로짓(Logit)을 생성
self.generator = nn.Linear(emb_size, tgt_vocab_size)
def forward(
self,
src,
trg,
src_mask,
tgt_mask,
src_padding_mask,
tgt_padding_mask,
memory_key_padding_mask,
):
src_emb = self.positional_encoding(self.src_tok_emb(src))
tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
outs = self.transformer(
src=src_emb,
tgt=tgt_emb,
src_mask=src_mask,
tgt_mask=tgt_mask,
memory_mask=None,
src_key_padding_mask=src_padding_mask,
tgt_key_padding_mask=tgt_padding_mask,
memory_key_padding_mask=memory_key_padding_mask
)
return self.generator(outs)
def encode(self, src, src_mask):
return self.transformer.encoder(
self.positional_encoding(self.src_tok_emb(src)), src_mask
)
def decode(self, tgt, memory, tgt_mask):
return self.transformer.decoder(
self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
)

Transformer의 추론 과정에는 Encoder stack에 전달되는 입력 시퀀스만 있고 Decoder stack에 입력할 타겟 시퀀스없이 출력 시퀀스를 생성하는 것이 목표이다.
Seq2Seq와 비교하였을 때, 이전 단계에 생성된 출력 시퀀스를 Decoder stack의 입력으로 사용하는 것은 동일하나 Seq2Seq가 이전 단계에서 생성된 단어를 전달하는 반면, Transformer는 누적된 전체 출력 시퀀스를 전달한다.
Seq2Seq의 Decoder에 전달되는 출력 시퀀스 : “<sos>” → “Je” → “vais”
Transformer의 Decoder에 전달되는 출력 시퀀스 : “<sos>” → “<sos>Je” → “<sos>Je vais”
