1_TiDoCLpkrGCrripMqPNcWg.webp

#트랜스포머 모델 구성
import math
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0)]
        return self.dropout(x)

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class Seq2SeqTransformer(nn.Module):
    def __init__(
        self,
        num_encoder_layers,
        num_decoder_layers,
        emb_size,
        max_len,
        nhead,
        src_vocab_size,
        tgt_vocab_size,
        dim_feedforward,
        dropout=0.1,
		    ):
		    #TokenEmbedding 클래스로 소스 데이터와 입력 데이터를 입력 임베딩으로 변환하여 src_tok_emb와 tgt_tok_emb를 생성
		    #소스와 타깃 데이터의 어휘 사전 크기를 입력받아 트랜스포머 임베딩 크기로 변환한다.
        super().__init__()
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        #PositionalEncoding을 적용해 트랜스포머 블록에 입력
        self.positional_encoding = PositionalEncoding(
            d_model=emb_size, max_len=max_len, dropout=dropout
        )
        #트랜스포머 블록: 파이토치 제공 트랜스포머 클래스 적용
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        #순방향 메서드 마지막에 적용되는 generator
        #마지막 트랜스포머 디코더 블록에서 산출되는 벡터를 선형 변환해 어휘 사전에 대한 로짓(Logit)을 생성
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def forward(
        self,
        src,
        trg,
        src_mask,
        tgt_mask,
        src_padding_mask,
        tgt_padding_mask,
        memory_key_padding_mask,
    ):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )
        return self.generator(outs)

    def encode(self, src, src_mask):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask
        )

    def decode(self, tgt, memory, tgt_mask):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
        )

Transformer의 학습 과정

1_S7DWA2VT5de28jm6AtDq-g.webp

  1. 입력 시퀀스(“I am good”)을 임베딩 후 Positional Encoding을 더한 후 Encoder로 입력한다.
  2. 입력 시퀀스는 Encoder Stack를 거치면 인코딩된 표현으로 변환된다.
  3. 타겟 시퀀스의 시작에 Start of Sentence 토큰을 알리는”<sos>”을 포함한 타겟 시퀀스에 임베딩 후 Positional Encoding을 더한 후 Decoder로 입력한다.
  4. Decoder Stack은 Encoder Stack에 의해 인코딩된 표현를 이용하여 타겟 시퀀스의 인코딩된 표현을 생성한다.
  5. Linear 레이어와 Softmax 레이어를 통과하면 단어의 확률을 계산하여 최종 출력 시퀀스로 변환한다.
  6. Transformer의 Loss 함수는 5단계에서 생성된 출력 시퀀스와 학습 데이터(“Je vais bien”)과 비교하고 back-propagation를 통해 학습한다.

Transformer의 추론 과정

Transformer의 추론 과정에는 Encoder stack에 전달되는 입력 시퀀스만 있고 Decoder stack에 입력할 타겟 시퀀스없이 출력 시퀀스를 생성하는 것이 목표이다.

Seq2Seq와 비교하였을 때, 이전 단계에 생성된 출력 시퀀스를 Decoder stack의 입력으로 사용하는 것은 동일하나 Seq2Seq가 이전 단계에서 생성된 단어를 전달하는 반면, Transformer는 누적된 전체 출력 시퀀스를 전달한다.

  1. 학습과정과 동일
  2. 학습과정과 동일
  3. 타겟 시퀀스가 없이 <sos> 시퀀스만을 임베딩 후 Positional Encoding을 더한 후 Decoder로 입력한다.
  4. 학습과정과 동일
  5. 학습과정과 동일