Deep Q-Network 알고리즘

2025. 2. 5. 21:31·AI/Deep Reinforcement Learning

1. 개요

DQN 은 강화학습 의 대표적인 알고리즘 중 하나로, 기존 Q-learning을 Neural Network과 결합하여 거대한 상태 공간에서도 학습할 수 있도록 만든 기법이다.

2013년, 구글 딥마인드  연구진이 Playing Atari with Deep Reinforcement Learning 논문을 공개하면서, 아타리-Atari 게임에서 DQN을 적용해 인간 수준의 성능을 달성하며 주목받았다.
DQN의 핵심 아이디어는 딥러닝을 활용하여 Q-value를 근사하고, 안정적인 학습을 위한 몇 가지 추가적인 기법을 적용하는 것이다.


2. Q-learning 복습

 Q-learning이란?

Q-learning은 각 State에서 최적의 Action을 선택하는 방법을 학습하는 강화학습 알고리즘이다.
핵심 개념은 Q-value를 업데이트하는 Bellman Equation 에 기반한다.

$$Q(s, a) \leftarrow Q(s, a) + \alpha \big( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \big)$$

  • s : 현재 상태 State

a : 현재 행동 Action

  • r : 보상 Reward
  • s' : 다음 상태 Next State
  • $\gamma$ : 할인율 Discount Factor
  • $\alpha$ : 학습률 Learning Rate

Q-learning의 문제점

  1. 상태 공간이 커지면 Q-table이 너무 커짐 → 저장 및 업데이트가 어려움.
  2. 보지 못한 상태에서는 Q-value를 예측할 수 없음 → generalization 부족.
  3. 연속적인 상태 공간에서는 적용이 어려움.

3. DQN의 핵심 아이디어

DQN은 위의 문제를 해결하기 위해 Q-value를 Deep Neural Network, DNN 으로 근사하는 방법을 사용한다.

DQN의 핵심 아이디어
Q-learning에서 사용하는 Q-table을 신경망으로 대체하여 근사- fitting하고, 안정적인 학습을 위해 몇 가지 추가 기법을 적용한다.

DQN의 주요 구성 요소

  1. 경험 재생-Experience Replay
  2. 타겟 네트워크-Target Network
  3. 신경망 기반 Q-value 근사
  4. 입출력 구조 및 손실 함수 정의
  5. 탐색과 활용의 균형-Exploration vs. Exploitation, ε-greedy 사용

4. DQN의 주요 기법

(1) 경험 재생-Experience Replay

강화학습에서는 에이전트가 환경과 상호작용하면서 데이터(경험)를 수집한다.
전통적인 Q-learning에서는 이 경험을 즉시 학습에 반영하지만, DQN에서는 이를 Replay Buffer 에 저장하고, 이후 랜덤 샘플링하여 학습한다.

  • 샘플 간 Correlation를 줄여 학습 안정성을 높인다.
  • 특정 보상이 적게 발생하는 경우에도 학습할 기회를 증가시킨다.
  • 데이터를 독립적으로 샘플링함으로써 비효율적인 학습을 방지한다.

어떻게 구현되는가?

  1. (s, a, r, s') 데이터를 저장하는 Replay Buffer 생성.
  2. 일정 시간 동안 경험을 버퍼에 저장.
  3. 학습할 때마다 버퍼에서 랜덤 샘플을 추출하여 신경망 업데이트.

 

(2) 타겟 네트워크-Target Network 

DQN에서 Q-value를 학습할 때 불안정한 학습 문제가 발생할 수 있다.
이를 해결하기 위해 두 개의 신경망을 사용한다.

  • Q-Network (현재 네트워크): 행동을 선택하는 메인 신경망.
  • Target Q-Network (타겟 네트워크): Q-value 업데이트에 사용되는 신경망.

타겟 네트워크의 역할

  • 학습 대상이 되는 Q-value를 계산할 때, 오래된 Q-network를 사용하여 업데이트를 부드럽게 만든다.
  • 타겟 네트워크는 일정 간격마다 현재 네트워크의 가중치를 복사하여 업데이트된다.

$$
y = r + \gamma \max_{a'} Q_{\text{target}}(s', a'; \theta^-)
$$
$$
\theta^- \leftarrow \theta \quad (\text{일정 주기마다 업데이트})
$$

  • $\theta$ : 현재 네트워크의 가중치
  • $\theta^-$ : 타겟 네트워크의 가중치

 

(3) 신경망 기반 Q-value 근사

기존 Q-learning에서는 Q-table을 사용했지만, DQN에서는 Q-value를 신경망으로 근사한다.

신경망 구조

  • 입력-Input: 현재 상태)
  • 출력-Output: 각 행동에 대한 Q-value
  • 손실 함수-Loss function:
    $$
    L(\theta) = \mathbb{E} \big[ (y - Q(s, a; \theta))^2 \big]
    $$
    • 여기서 $ y = r + \gamma \max Q_{\text{target}}(s', a') $

 

Q-value 업데이트

$$
\theta \leftarrow \theta - \alpha \nabla_{\theta} L(\theta)
$$

 

(4) 경험 재생을 활용한 배치 업데이트

DQN은 단순한 온라인 업데이트 대신, Experience Replay 기법을 활용한 배치 업데이트를 사용한다.

 

배치 업데이트 수식

$$
w_{t+1} \leftarrow w_t + \alpha \frac{1}{N} \sum_{i=1}^{N} \left[ r_i + \left( (1-done_i) . \gamma . \max_{a'} \hat{q}(s'{i},a';w^{-}{t}) \right) – \hat{q}(s_i,a_i;w_t) \right] \nabla \hat{q}(s_i,a_i;w_t)
$$

  • $N $: 배치 크기.
  • $ (s_i, a_i, r_i, s'_i, done_i)$: 경험 메모리에서 샘플링한 데이터.

 

(5) 탐색과 활용 (Exploration vs. Exploitation)

강화학습에서는 탐색-Exploration과 활용-Exploitation의 균형이 중요하다.

  • Exploration: 새로운 행동을 시도하여 더 좋은 보상을 찾는 과정.
  • Exploitation: 현재까지의 학습을 활용하여 최적 행동을 선택.

DQN에서는 ε-greedy 정책을 사용하여 탐색과 활용을 조절한다.

$$
P(a) =
\begin{cases}
1 - \epsilon & \text{(현재 Q-value가 최대인 행동)} \
\epsilon & \text{(랜덤한 행동)}
\end{cases}
$$

  • 초기에는 ε를 크게 설정하여 랜덤 행동을 많이 시도.
  • 학습이 진행될수록 ε을 점진적으로 감소시켜 최적 정책을 학습.

5. DQN 학습 과정

1️⃣ 환경에서 행동 수행 및 경험 저장

  • 현재 상태 (s)에서 행동 (a)를 선택하고, 보상 (r)과 다음 상태 (s')를 저장.

2️⃣ Experience Replay

  • Replay Buffer에서 랜덤 샘플을 추출하여 학습 데이터로 사용.

3️⃣ Q-network 업데이트

  • 타겟 Q-value를 계산하고, 손실함수(loss)를 줄이도록 신경망을 업데이트.

4️⃣ 타겟 네트워크 업데이트

  • 일정 주기마다 타겟 네트워크를 현재 네트워크로 갱신.

5️⃣ 반복 학습

  • 위 과정을 반복하며 최적의 Q-value를 학습.

6. 코드 구현

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def add(self, experience):
        self.buffer.append(experience)
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(states),
            torch.LongTensor(actions).unsqueeze(1),
            torch.FloatTensor(rewards),
            torch.FloatTensor(next_states),
            torch.FloatTensor(dones),
        )
    
    def __len__(self):
        return len(self.buffer)

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # Q-values for each action

class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01, memory_size=10000, batch_size=64):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        
        self.q_network = DQN(state_dim, action_dim)
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.criterion = nn.MSELoss()
        self.memory = ReplayBuffer(capacity=memory_size)
    
    def act(self, state):
        if random.random() < self.epsilon:
            return np.random.randint(self.action_dim)
        state = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():
            return torch.argmax(self.q_network(state)).item()
    
    def update_memory(self, experience):
        self.memory.add(experience)
    
    def train(self):
        if len(self.memory) < self.batch_size:
            return
        
        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        
        q_values = self.q_network(states).gather(1, actions).squeeze(1)
        next_q_values = self.q_network(next_states).max(dim=1)[0]
        target_q_values = rewards + (self.gamma * next_q_values * (1 - dones))
        
        loss = self.criterion(q_values, target_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

 


6. 결론

DQN은 Q-learning을 신경망으로 확장하여 거대한 상태 공간에서도 학습할 수 있도록 개선된 알고리즘이다.
경험 재생과 타겟 네트워크를 도입하여 학습 안정성을 높였으며, 이후 다양한 개선 기법, Double DQN, Dueling DQN 등 으로 발전했다.

핵심 요약

  • Q-table 대신 신경망으로 Q-value를 근사.
  • 경험 재생과 타겟 네트워크로 학습 안정성 향상.
  • 게임 AI, 로보틱스, 금융 트레이딩 등 다양한 분야에서 활용.

DQN의 발전은 Deep Reinforcement Learning 의 가능성을 열어주었으며, 이후 다양한 알고리즘의 기반이 되었다.

 

 

'AI > Deep Reinforcement Learning' 카테고리의 다른 글

DQN의 개선 - 우선순위 경험재헌, DDQN, Dueling-DQN  (0) 2025.02.28
'AI/Deep Reinforcement Learning' 카테고리의 다른 글
  • DQN의 개선 - 우선순위 경험재헌, DDQN, Dueling-DQN
Juson
Juson
  • Juson
    Juson의 데이터 공부
    Juson
  • 전체
    오늘
    어제
    • 분류 전체보기 (95)
      • RAG (2)
      • AI (2)
        • NLP (0)
        • Generative Model (0)
        • Deep Reinforcement Learning (2)
        • LLM (0)
      • Logistic Optimization (0)
      • Machine Learning (37)
        • Linear Regression (2)
        • Logistic Regression (2)
        • Decision Tree (5)
        • Naive Bayes (1)
        • KNN (2)
        • SVM (2)
        • Clustering (4)
        • Dimension Reduction (3)
        • Boosting (6)
        • Abnomaly Detection (2)
        • Recommendation (4)
        • Embedding & NLP (4)
      • Reinforcement Learning (5)
      • Deep Learning (10)
        • Deep learning Bacis Mathema.. (10)
      • Optimization (2)
        • OR Optimization (0)
        • Convex Optimization (0)
        • Integer Optimization (0)
      • SNA 분석 (0)
      • 포트폴리오 최적화 공부 (0)
        • 최적화 기법 (0)
        • 금융 베이스 (0)
      • Finanancial engineering (0)
      • 프로그래머스 데브코스(Boot camp) (15)
        • SQL (9)
        • Python (5)
        • Machine Learning (1)
      • Python (22)
      • Project (0)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 링크

  • 공지사항

  • 인기 글

  • 태그

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.4
Juson
Deep Q-Network 알고리즘
상단으로

티스토리툴바