Replay Buffer
class ReplayBuffer:
    def __init__(self, buffer_size, state_dim, action_dim, device="cpu"):
        self.buffer_size = buffer_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        # pre-allocate fixed-capacity arrays; uint8 keeps the image frames compact
        self.observations = np.zeros((buffer_size, *state_dim), dtype=np.uint8)
        self.next_observations = np.zeros((buffer_size, *state_dim), dtype=np.uint8)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.int64)
        self.rewards = np.zeros((buffer_size,), dtype=np.float32)
        self.dones = np.zeros((buffer_size,), dtype=np.float32)
        self.pos = 0       # next write position (circular)
        self.full = False  # becomes True once the buffer has wrapped around

    def add(self, obs, next_obs, action, reward, done):
        self.observations[self.pos] = obs
        self.next_observations[self.pos] = next_obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        self.pos = (self.pos + 1) % self.buffer_size
        if self.pos == 0:
            self.full = True

    def sample(self, batch_size):
        # sample only from the portion of the buffer that has been filled
        total = self.buffer_size if self.full else self.pos
        indices = np.random.choice(total, batch_size, replace=False)
        obs_batch = self.observations[indices]
        next_obs_batch = self.next_observations[indices]
        actions_batch = self.actions[indices]
        rewards_batch = self.rewards[indices]
        dones_batch = self.dones[indices]
        return (
            torch.tensor(obs_batch, dtype=torch.uint8, device=self.device),
            torch.tensor(actions_batch, dtype=torch.int64, device=self.device),
            torch.tensor(next_obs_batch, dtype=torch.uint8, device=self.device),
            torch.tensor(rewards_batch, dtype=torch.float32, device=self.device).unsqueeze(1),
            torch.tensor(dones_batch, dtype=torch.float32, device=self.device).unsqueeze(1),
        )
- Allocates np.zeros arrays up front at the maximum capacity and writes transitions into them in place.
- Using a built-in Python container such as deque instead adds noticeable CPU overhead compared to pre-allocated NumPy arrays.
- sample draws random indices only from the portion of the buffer that has actually been filled (see the usage sketch below).
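A minimal usage sketch (my addition, not part of the original post), assuming the ReplayBuffer class above and Breakout-sized stacked-frame observations of shape (4, 84, 84); the buffer size and dummy transitions are made up for illustration:

import numpy as np

# hypothetical small buffer for (4, 84, 84) stacked grayscale frames
rb = ReplayBuffer(buffer_size=1000, state_dim=(4, 84, 84), action_dim=1, device="cpu")

# fill with dummy transitions so that sampling becomes possible
for _ in range(64):
    obs = np.random.randint(0, 256, size=(4, 84, 84), dtype=np.uint8)
    next_obs = np.random.randint(0, 256, size=(4, 84, 84), dtype=np.uint8)
    action = np.random.randint(0, 4)  # e.g. 4 discrete actions in Breakout
    rb.add(obs, next_obs, action, reward=1.0, done=False)

obs_b, act_b, next_obs_b, rew_b, done_b = rb.sample(batch_size=32)
print(obs_b.shape, act_b.shape, rew_b.shape)
# torch.Size([32, 4, 84, 84]) torch.Size([32, 1]) torch.Size([32, 1])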
Q-network
class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),
        )

    def forward(self, x):
        return self.network(x / 255.0)
- The network takes image input of shape (batch size, channels, height, width).
- $$
  \text{Output size} = \frac{\text{Input size} - \text{Filter size} + 2 \times \text{Padding}}{\text{Stride}} + 1
  $$
- nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
Input: (32, 4, 84, 84)  # batch size 32, 4 channels, 84x84 image
  │
  └─ Conv2d(4, 32, kernel_size=8, stride=4) → (32, 32, 20, 20)
       # 4 input channels → 32 channels, 8x8 kernel, stride 4
  │
  └─ Conv2d(32, 64, kernel_size=4, stride=2) → (32, 64, 9, 9)
       # 32 input channels → 64 channels, 4x4 kernel, stride 2
  │
  └─ Conv2d(64, 64, kernel_size=3, stride=1) → (32, 64, 7, 7)
       # channels stay at 64, 3x3 kernel, stride 1
  │
  └─ Flatten → (32, 3136)
       # 64 × 7 × 7 = 3136
  │
  └─ Linear(3136, 512) → (32, 512)
  │
  └─ Linear(512, 4) → (32, 4)
       # output size equals the number of actions (e.g. 4 in Breakout)
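The shape trace above can be checked with a quick sketch (my addition, not from the original post) that pushes a dummy batch through the convolutional part of the network layer by layer:

import torch
import torch.nn as nn

conv_part = nn.Sequential(
    nn.Conv2d(4, 32, 8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
    nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
    nn.Flatten(),
)

x = torch.zeros(32, 4, 84, 84)  # dummy batch: (batch, channels, height, width)
for layer in conv_part:
    x = layer(x)
    print(type(layer).__name__, tuple(x.shape))
# prints (32, 32, 20, 20) after the first conv, (32, 64, 9, 9) after the second,
# (32, 64, 7, 7) after the third, and (32, 3136) after Flatten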
$\epsilon$-Greedy
def linear_schedule(start_e, end_e, duration, t):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

epsilon = linear_schedule(start_e, end_e, exploration_fraction * total_timesteps, global_step)
if random.random() < epsilon:
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
    q_values = q_network(torch.Tensor(state).to(device))
    actions = torch.argmax(q_values, dim=1).cpu().numpy()
next_state, rewards, terminations, truncations, infos = envs.step(actions)
- If the random draw is below epsilon, one of the 4 actions (in Breakout) is chosen uniformly at random.
- Otherwise, the action with the largest value from the behavior Q-network is selected greedily (a short sketch of the epsilon schedule follows below).
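For intuition, here is a short sketch (my addition) of the values produced by linear_schedule with the hyperparameters used later in the full code (start_e=1.0, end_e=0.1, exploration_fraction=0.1, total_timesteps=10,000,000), i.e. a linear decay over the first 1,000,000 steps followed by a constant 0.1:

start_e, end_e = 1.0, 0.1
duration = 0.1 * 10_000_000  # exploration_fraction * total_timesteps

for t in [0, 250_000, 500_000, 1_000_000, 5_000_000]:
    print(t, linear_schedule(start_e, end_e, duration, t))
# 0        1.0
# 250000   0.775
# 500000   0.55
# 1000000  0.1
# 5000000  0.1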
Train
obs, actions, next_obs, rewards, dones = rb.sample(batch_size)
with torch.no_grad():
    target_max = target_network(next_obs).max(1)[0]
    td_target = rewards.flatten() + gamma * target_max * (1 - dones.flatten())
q_values = q_network(obs)
q_action = q_values.gather(1, actions).squeeze()
td_error = td_target - q_action
loss = (td_error ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
- Extracts batch_size samples from the replay memory and minimizes the squared TD error below (a small numeric sketch of the gather/mask step follows the equation).
- $$
  L_i(\theta_i) = \mathbb{E}_{s, a, r, s' \sim D}\Big[\big(\underbrace{r + \gamma \max_{a'} Q(s', a'; \theta_i^{-})}_{\text{target}} - Q(s, a; \theta_i)\big)^2\Big]
  $$
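A small sketch (my addition, with made-up numbers) of what gather and the (1 - done) mask do in the update above:

import torch

q_values = torch.tensor([[0.1, 0.5, 0.2, 0.0],   # Q(s, ·) for two sampled states
                         [0.3, 0.1, 0.9, 0.4]])
actions = torch.tensor([[1], [2]])                # actions actually taken in the buffer
q_action = q_values.gather(1, actions).squeeze()  # -> tensor([0.5, 0.9]): Q(s, a)

rewards = torch.tensor([1.0, 0.0])
dones = torch.tensor([0.0, 1.0])                  # the second transition ended the episode
target_max = torch.tensor([0.7, 0.8])             # max_a' Q_target(s', a'), made-up values
gamma = 0.99
td_target = rewards + gamma * target_max * (1 - dones)
# -> tensor([1.6930, 0.0000]): the bootstrap term is dropped on terminal transitions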
Result
- Results from training for 10 million steps in the Breakout environment.
- Since this is the plain classic DQN, there is no other variant to compare against yet.
Full Code
import os
import random
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)


def make_env(env_id, seed, idx, capture_video):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        print(env.observation_space)
        env.action_space.seed(seed)
        return env
    return thunk


class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),
        )

    def forward(self, x):
        return self.network(x / 255.0)


def linear_schedule(start_e, end_e, duration, t):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


class ReplayBuffer:
    def __init__(self, buffer_size, state_dim, action_dim, device="cpu"):
        self.buffer_size = buffer_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.observations = np.zeros((buffer_size, *state_dim), dtype=np.uint8)
        self.next_observations = np.zeros((buffer_size, *state_dim), dtype=np.uint8)
        self.actions = np.zeros((buffer_size, action_dim), dtype=np.int64)
        self.rewards = np.zeros((buffer_size,), dtype=np.float32)
        self.dones = np.zeros((buffer_size,), dtype=np.float32)
        self.pos = 0
        self.full = False

    def add(self, obs, next_obs, action, reward, done):
        self.observations[self.pos] = obs
        self.next_observations[self.pos] = next_obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        self.pos = (self.pos + 1) % self.buffer_size
        if self.pos == 0:
            self.full = True

    def sample(self, batch_size):
        total = self.buffer_size if self.full else self.pos
        indices = np.random.choice(total, batch_size, replace=False)
        obs_batch = self.observations[indices]
        next_obs_batch = self.next_observations[indices]
        actions_batch = self.actions[indices]
        rewards_batch = self.rewards[indices]
        dones_batch = self.dones[indices]
        return (
            torch.tensor(obs_batch, dtype=torch.uint8, device=self.device),
            torch.tensor(actions_batch, dtype=torch.int64, device=self.device),
            torch.tensor(next_obs_batch, dtype=torch.uint8, device=self.device),
            torch.tensor(rewards_batch, dtype=torch.float32, device=self.device).unsqueeze(1),
            torch.tensor(dones_batch, dtype=torch.float32, device=self.device).unsqueeze(1),
        )


if __name__ == "__main__":
    model_path = "./model"
    os.makedirs(model_path, exist_ok=True)

    seed = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env_name = "BreakoutNoFrameskip-v4"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    learning_rate = 1e-4
    buffer_size = 100000
    total_timesteps = 10000000
    start_e = 1.0
    end_e = 0.1
    exploration_fraction = 0.1
    learning_starts = 80000
    train_frequency = 4
    batch_size = 32
    gamma = 0.99
    target_network_frequency = 1000
    tau = 1.0
    episode = 0
    use_wandb = False

    if use_wandb:
        import wandb
        wandb.init(
            project="dqn-breakout",
            config={
                "env_name": env_name,
                "total_timesteps": total_timesteps,
                "learning_rate": learning_rate,
                "buffer_size": buffer_size,
                "batch_size": batch_size,
                "gamma": gamma,
                "start_e": start_e,
                "end_e": end_e,
                "exploration_fraction": exploration_fraction,
                "train_frequency": train_frequency,
                "learning_starts": learning_starts,
                "target_network_frequency": target_network_frequency,
                "tau": tau,
                "seed": seed,
            },
        )

    envs = gym.vector.SyncVectorEnv(
        [make_env(env_name, seed + i, i, False) for i in range(1)]
    )

    q_network = QNetwork(envs).to(device)
    optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
    target_network = QNetwork(envs).to(device)
    target_network.load_state_dict(q_network.state_dict())

    obs_shape = envs.single_observation_space.shape
    action_shape = (1,)  # for a Discrete action space
    rb = ReplayBuffer(
        buffer_size=buffer_size,
        state_dim=obs_shape,
        action_dim=action_shape[0],
        device=device,
    )

    state, _ = envs.reset(seed=seed)
    for global_step in range(total_timesteps):
        epsilon = linear_schedule(start_e, end_e, exploration_fraction * total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network(torch.Tensor(state).to(device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        next_state, rewards, terminations, truncations, infos = envs.step(actions)

        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    episode += 1
                    print(f"steps:{global_step}, episode:{episode}, reward:{info['episode']['r']}, stepLength:{info['episode']['l']}")
                    if use_wandb:
                        wandb.log(
                            {
                                "episode": episode,
                                "episodic_return": info["episode"]["r"],
                                "episodic_length": info["episode"]["l"],
                                "epsilon": epsilon,
                                "global_step": global_step,
                            }
                        )

        real_next_state = next_state.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_state[idx] = infos["final_observation"][idx]
        rb.add(state, real_next_state, actions, rewards, terminations)
        state = next_state

        if global_step > learning_starts:
            if global_step % train_frequency == 0:
                obs, actions, next_obs, rewards, dones = rb.sample(batch_size)
                with torch.no_grad():
                    target_max = target_network(next_obs).max(1)[0]
                    td_target = rewards.flatten() + gamma * target_max * (1 - dones.flatten())
                q_values = q_network(obs)
                q_action = q_values.gather(1, actions).squeeze()
                td_error = td_target - q_action
                loss = (td_error ** 2).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if use_wandb:
                    wandb.log({
                        "loss": loss.item(),
                        "global_step": global_step
                    })

            if global_step % target_network_frequency == 0:
                for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                    target_network_param.data.copy_(
                        tau * q_network_param.data + (1.0 - tau) * target_network_param.data
                    )

            if episode % 1000 == 0:
                model_file = os.path.join(model_path, f"Breakout_dqn_classic_{episode}.pth")
                torch.save(q_network.state_dict(), model_file)
                print(f"✅ Saved model at episode {episode} to {model_file}")

    envs.close()