#DQN - 카트폴
#SAINT Lab. Q1 [강화학습]
#60201969 이유현 [2024.01.29]
import gym #OpenAI GYM 라이브러리
import collections #선입선출의 특성을 갖고 있는 리플레이 버퍼 구현
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#Hyperparameters: 정답이 아닌 하나의 예시
learning_rate = 0.0005
gamma = 0.98
buffer_limit = 50000
batch_size = 32 #하나의 미니 배치 안에 32개의 데이터
class ReplayBuffer(): #5만 개의 최신 데이터를 저장해두었다가 batch_size만큼의 데이터 제공
def __init__(self):
self.buffer = collections.deque(maxlen=buffer_limit)
def put(self, transition): #데이터를 버퍼에 넣어주는 함수
self.buffer.append(transition)
def sample(self, n): #버퍼에서 랜덤하게 32개의 데이터를 뽑아서 미니 배치를 구성해주는 함수
mini_batch = random.sample(self.buffer, n)
s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], [] #하나의 데이터의 구성, done_mask: 종료 상태의 밸류를 마스킹
for transition in mini_batch:
s, a, r, s_prime, done_mask = transition
s_lst.append(s)
a_lst.append([a])
r_lst.append([r])
s_prime_lst.append(s_prime)
done_mask_lst.append([done_mask])
return torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \\
torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \\
torch.tensor(done_mask_lst)
def size(self):
return len(self.buffer)
class Qnet(nn.Module): #Q밸류 네트워크
def __init__(self):
super(Qnet, self).__init__()
self.fc1 = nn.Linear(4, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 2)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x) #마지막에는 relu 사용 x: 마지막 아웃웃은 Q밸류이기 때문에 어느 값이든 가능(relu는 양수만 리턴)
return x
def sample_action(self, obs, epsilon): #실제로 행할 액션을 입실론-greedy 방식으로 선택
out = self.forward(obs)
coin = random.random()
if coin < epsilon: #코인이 입실론보다 작으면 랜덤 액션
return random.randint(0,1)
else : #크면 Q값이 제일 큰 액션 선택
return out.argmax().item()
def train(q, q_target, memory, optimizer): #학습 함수
for i in range(10): #10개의 미니 배치를 뽑아 10번 업데이트: 1번 업데이트에 32개 데이터 사용 = 한 에피소드마다 320개의 데이터 사용
s,a,r,s_prime,done_mask = memory.sample(batch_size)
q_out = q(s)
q_a = q_out.gather(1,a) #실제 선택된 액션의 Q값
max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1) #정답지 계산(q_target 네트워크 호출)
target = r + gamma * max_q_prime * done_mask
loss = F.smooth_l1_loss(q_a, target) #loss값 계산
optimizer.zero_grad()
loss.backward() #그라디언트 계산
optimizer.step() #Qnet의 파라미터 업데이트
def main():
env = gym.make('CartPole-v1')
q = Qnet()
q_target = Qnet()
q_target.load_state_dict(q.state_dict()) #Q네트워크의 파라미터들을 타깃 네트워크로 복사: 초기에는 동일
memory = ReplayBuffer()
print_interval = 20
score = 0.0
optimizer = optim.Adam(q.parameters(), lr=learning_rate) #q 네트워크의 파라미터만 업데이트
for n_epi in range(10000):
epsilon = max(0.01, 0.08 - 0.01*(n_epi/200)) #Linear annealing from 8% to 1%
s, _ = env.reset()
done = False
while not done:
a = q.sample_action(torch.from_numpy(s).float(), epsilon)
s_prime, r, done, truncated, info = env.step(a)
done_mask = 0.0 if done else 1.0
memory.put((s,a,r/100.0,s_prime, done_mask))
s = s_prime
score += r
if done:
break
if memory.size()>2000: #초기 2천개의 데이터 축적 이후부터 학습 시작(재사용으로 인해 학습이 치우치는 것 방지)
train(q, q_target, memory, optimizer)
if n_epi%print_interval==0 and n_epi!=0: #에피소드가 10개 끝날 때마다
q_target.load_state_dict(q.state_dict()) #q 네트워크의 파라미터를 q_target 네트워크로 복사
print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format( #가장 최근 10개 에피소드의 보상 총합의 평균을 프린트
n_epi, score/print_interval, memory.size(), epsilon*100))
score = 0.0
env.close()
if __name__ == '__main__':
main()
<aside> 📢 파이토치 설치 실패로 실구현 보류
</aside>