카트폴 환경에서 DQN 에이전트 구현

#DQN - 카트폴
#SAINT Lab. Q1 [강화학습]
#60201969 이유현 [2024.01.29]
import gym #OpenAI GYM 라이브러리
import collections #선입선출의 특성을 갖고 있는 리플레이 버퍼 구현
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

#Hyperparameters: 정답이 아닌 하나의 예시
learning_rate = 0.0005
gamma         = 0.98
buffer_limit  = 50000
batch_size    = 32 #하나의 미니 배치 안에 32개의 데이터

class ReplayBuffer(): #5만 개의 최신 데이터를 저장해두었다가 batch_size만큼의 데이터 제공
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)
    
    def put(self, transition): #데이터를 버퍼에 넣어주는 함수
        self.buffer.append(transition)
    
    def sample(self, n): #버퍼에서 랜덤하게 32개의 데이터를 뽑아서 미니 배치를 구성해주는 함수
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], [] #하나의 데이터의 구성, done_mask: 종료 상태의 밸류를 마스킹
        
        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \\
               torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \\
               torch.tensor(done_mask_lst)
    
    def size(self):
        return len(self.buffer)

class Qnet(nn.Module): #Q밸류 네트워크
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x) #마지막에는 relu 사용 x: 마지막 아웃웃은 Q밸류이기 때문에 어느 값이든 가능(relu는 양수만 리턴)
        return x
      
    def sample_action(self, obs, epsilon): #실제로 행할 액션을 입실론-greedy 방식으로 선택
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon: #코인이 입실론보다 작으면 랜덤 액션
            return random.randint(0,1)
        else : #크면 Q값이 제일 큰 액션 선택
            return out.argmax().item()
            
def train(q, q_target, memory, optimizer): #학습 함수
    for i in range(10): #10개의 미니 배치를 뽑아 10번 업데이트: 1번 업데이트에 32개 데이터 사용 = 한 에피소드마다 320개의 데이터 사용
        s,a,r,s_prime,done_mask = memory.sample(batch_size) 

        q_out = q(s)
        q_a = q_out.gather(1,a) #실제 선택된 액션의 Q값
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1) #정답지 계산(q_target 네트워크 호출)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target) #loss값 계산
        
        optimizer.zero_grad()
        loss.backward() #그라디언트 계산
        optimizer.step() #Qnet의 파라미터 업데이트

def main():
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict()) #Q네트워크의 파라미터들을 타깃 네트워크로 복사: 초기에는 동일
    memory = ReplayBuffer()

    print_interval = 20
    score = 0.0  
    optimizer = optim.Adam(q.parameters(), lr=learning_rate) #q 네트워크의 파라미터만 업데이트

    for n_epi in range(10000):
        epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
        s, _ = env.reset()
        done = False

        while not done:
            a = q.sample_action(torch.from_numpy(s).float(), epsilon)      
            s_prime, r, done, truncated, info = env.step(a)
            done_mask = 0.0 if done else 1.0
            memory.put((s,a,r/100.0,s_prime, done_mask))
            s = s_prime

            score += r
            if done:
                break
            
        if memory.size()>2000: #초기 2천개의 데이터 축적 이후부터 학습 시작(재사용으로 인해 학습이 치우치는 것 방지)
            train(q, q_target, memory, optimizer)

        if n_epi%print_interval==0 and n_epi!=0: #에피소드가 10개 끝날 때마다
            q_target.load_state_dict(q.state_dict()) #q 네트워크의 파라미터를 q_target 네트워크로 복사
            print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format( #가장 최근 10개 에피소드의 보상 총합의 평균을 프린트
                                                            n_epi, score/print_interval, memory.size(), epsilon*100))
            score = 0.0
    env.close()

if __name__ == '__main__':
    main()

<aside> 📢 파이토치 설치 실패로 실구현 보류

</aside>