강화학습

강화 학습: Soft Actor-Critic (SAC) 알고리즘

move84 2025. 4. 6. 09:47
반응형

강화 학습 (Reinforcement Learning, RL) 분야는 에이전트가 환경과 상호 작용하며 보상을 최대화하도록 학습하는 방법을 연구합니다. Soft Actor-Critic (SAC)은 이 분야에서 최근 각광받는 알고리즘 중 하나입니다. SAC는 안정적인 학습과 효율적인 탐색을 가능하게 하여 복잡한 환경에서도 좋은 성능을 보입니다. 이 글에서는 SAC 알고리즘의 핵심 개념, 작동 방식, 그리고 구현 예시를 자세히 살펴보겠습니다.

🚀 SAC의 기본 개념 (Basic Concepts of SAC)

SAC는 액터-크리틱 (Actor-Critic) 기반의 강화 학습 알고리즘입니다. 액터는 환경에서 행동을 선택하는 역할을 담당하고, 크리틱은 선택된 행동의 가치를 평가합니다. SAC는 여기에 엔트로피 (entropy)를 추가하여 탐색을 더욱 효과적으로 만듭니다. 엔트로피는 무작위성을 나타내는 지표로, SAC는 보상과 함께 엔트로피를 최대화하여 에이전트가 다양한 행동을 시도하도록 장려합니다. 이를 통해 에이전트는 더 넓은 범위의 행동을 탐색하고, 보다 안정적인 학습을 수행할 수 있습니다.

💡 SAC의 작동 원리 (How SAC Works)

SAC는 다음과 같은 주요 구성 요소와 단계로 작동합니다:

  1. 액터 (Actor): 액터는 현재 상태에서 어떤 행동을 할지 결정하는 정책(policy)을 표현합니다. SAC에서는 정책을 학습 가능한 신경망으로 나타내며, 행동의 확률 분포를 출력합니다.
  2. 크리틱 (Critic): 크리틱은 두 가지 역할을 수행합니다: Q-함수 (Q-function)와 가치 함수 (Value function)를 추정합니다. Q-함수는 특정 상태에서 특정 행동을 했을 때 얻을 수 있는 예상 보상을 나타내고, 가치 함수는 특정 상태에서 얻을 수 있는 예상 총 보상을 나타냅니다. SAC에서는 Q-함수와 가치 함수 모두 학습 가능한 신경망으로 표현합니다.
  3. 엔트로피 (Entropy): 엔트로피는 정책의 무작위성을 측정합니다. SAC는 보상과 엔트로피의 가중치 합을 최대화하도록 학습합니다. 이를 통해 에이전트는 탐색을 장려받고, 지역 최적해에 갇히는 것을 방지할 수 있습니다.
  4. 학습 과정 (Learning Process): SAC는 다음과 같은 단계로 학습을 진행합니다:
    • 환경에서 상태를 관찰하고, 액터가 생성한 확률 분포에 따라 행동을 선택합니다.
    • 선택된 행동을 환경에 적용하고, 보상과 다음 상태를 받습니다.
    • 경험을 저장하고, 이를 사용하여 Q-함수, 가치 함수, 그리고 액터를 업데이트합니다. 업데이트 과정에서는 손실 함수 (loss function)를 최소화하는 방식으로 진행됩니다.

🔍 SAC의 핵심 요소 (Key Elements of SAC)

  • 정책 (Policy): 에이전트가 환경에서 행동을 선택하는 방법을 결정하는 함수. SAC에서는 확률적 정책 (stochastic policy)을 사용하며, 이는 각 행동에 대한 확률 분포를 출력합니다.
  • Q-함수 (Q-function): 특정 상태에서 특정 행동을 했을 때 얻을 수 있는 예상 보상을 나타내는 함수. SAC에서는 Q-함수를 사용하여 행동의 가치를 평가합니다.
  • 가치 함수 (Value function): 특정 상태에서 얻을 수 있는 예상 총 보상을 나타내는 함수. SAC에서는 가치 함수를 사용하여 상태의 가치를 평가합니다.
  • 엔트로피 (Entropy): 정책의 무작위성을 측정하는 지표. SAC에서는 엔트로피를 보상과 함께 최대화하여 탐색을 장려합니다.
  • 온도 (Temperature, Alpha): 엔트로피의 가중치를 조절하는 하이퍼파라미터. 온도 값은 탐색과 활용 사이의 균형을 조절합니다.

💻 SAC 알고리즘의 파이썬 구현 예시 (Python Implementation Example of SAC)

다음은 SAC 알고리즘의 간단한 파이썬 구현 예시입니다. 이 예시는 기본적인 SAC 구조를 보여주기 위한 것이며, 실제 환경에서의 구현과는 차이가 있을 수 있습니다.

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 1. 신경망 정의 (Neural Network Definition)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        self.tanh = nn.Tanh()

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        action = self.tanh(self.fc3(x))
        return action

class Critic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.fc3(x)
        return q_value

# 2. SAC 에이전트 클래스 (SAC Agent Class)
class SACAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2):
        self.actor = Actor(state_dim, action_dim)
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.critic1_target = Critic(state_dim, action_dim)
        self.critic2_target = Critic(state_dim, action_dim)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=lr)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=lr)

        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))
        with torch.no_grad():
            action = self.actor(state).cpu().numpy().flatten()
        return action

    def train(self, replay_buffer, batch_size=256):
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        # 1. 크리틱 업데이트 (Critic Update)
        with torch.no_grad():
            next_actions = self.actor(next_states)
            q1_target = self.critic1_target(next_states, next_actions)
            q2_target = self.critic2_target(next_states, next_actions)
            q_target = torch.min(q1_target, q2_target)
            q_target = rewards + self.gamma * (1 - dones) * q_target

        q1_pred = self.critic1(states, actions)
        q2_pred = self.critic2(states, actions)
        critic1_loss = torch.mean((q1_pred - q_target).pow(2))
        critic2_loss = torch.mean((q2_pred - q_target).pow(2))

        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()

        # 2. 액터 업데이트 (Actor Update)
        policy_actions = self.actor(states)
        q1_policy = self.critic1(states, policy_actions)
        q2_policy = self.critic2(states, policy_actions)
        q_policy = torch.min(q1_policy, q2_policy)
        actor_loss = torch.mean(self.alpha * torch.sum(torch.log(1 - policy_actions.pow(2) + 1e-6), dim=1) - q_policy)

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 3. 타겟 네트워크 업데이트 (Target Network Update)
        for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

# 3. 리플레이 버퍼 (Replay Buffer)
class ReplayBuffer:
    def __init__(self, state_dim, action_dim, buffer_size=100000):
        self.state_buffer = np.zeros((buffer_size, state_dim))
        self.action_buffer = np.zeros((buffer_size, action_dim))
        self.reward_buffer = np.zeros(buffer_size)
        self.next_state_buffer = np.zeros((buffer_size, state_dim))
        self.done_buffer = np.zeros(buffer_size)
        self.buffer_size = buffer_size
        self.idx = 0
        self.size = 0

    def store(self, state, action, reward, next_state, done):
        self.state_buffer[self.idx] = state
        self.action_buffer[self.idx] = action
        self.reward_buffer[self.idx] = reward
        self.next_state_buffer[self.idx] = next_state
        self.done_buffer[self.idx] = done
        self.idx = (self.idx + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def sample(self, batch_size):
        idx = np.random.choice(self.size, batch_size, replace=False)
        return (torch.FloatTensor(self.state_buffer[idx]),
                torch.FloatTensor(self.action_buffer[idx]),
                torch.FloatTensor(self.reward_buffer[idx].reshape(-1, 1)),
                torch.FloatTensor(self.next_state_buffer[idx]),
                torch.FloatTensor(self.done_buffer[idx].reshape(-1, 1)))

위 코드는 SAC 에이전트, 크리틱, 액터, 그리고 리플레이 버퍼를 정의합니다. Actor 클래스는 상태를 입력받아 행동을 출력하는 신경망을, Critic 클래스는 상태와 행동을 입력받아 Q-값을 출력하는 신경망을 나타냅니다. SACAgent 클래스는 액터와 크리틱을 포함하며, 학습 및 행동 선택 기능을 제공합니다. ReplayBuffer는 경험을 저장하고 샘플링하는 데 사용됩니다. 이 예시에서는 간단한 2개의 은닉층을 가진 신경망을 사용했지만, 더 복잡한 구조를 사용할 수도 있습니다. 실제 환경에 적용하기 위해서는 환경에 맞는 상태와 행동의 차원을 설정하고, 하이퍼파라미터를 조정해야 합니다.

💡 SAC 알고리즘의 장점 (Advantages of SAC)

  • 안정적인 학습 (Stable Learning): 엔트로피를 사용함으로써 탐색을 장려하고, 학습의 안정성을 높입니다.
  • 효율적인 탐색 (Efficient Exploration): 엔트로피 최대화를 통해 다양한 행동을 탐색하고, 지역 최적해에 갇히는 것을 방지합니다.
  • 높은 성능 (High Performance): 복잡한 환경에서도 좋은 성능을 보이며, 다양한 문제에 적용될 수 있습니다.

⚠️ SAC 알고리즘의 단점 (Disadvantages of SAC)

  • 하이퍼파라미터 튜닝 (Hyperparameter Tuning): 다른 강화 학습 알고리즘과 마찬가지로, SAC 역시 하이퍼파라미터 튜닝에 민감합니다. 특히, 엔트로피의 가중치를 조절하는 온도(alpha)를 적절하게 설정하는 것이 중요합니다.
  • 계산 비용 (Computational Cost): SAC는 액터와 크리틱, 그리고 타겟 네트워크를 모두 학습해야 하므로, 계산 비용이 비교적 높습니다. 특히, 복잡한 환경에서는 더 많은 계산 자원이 필요할 수 있습니다.

📚 결론 (Conclusion)

Soft Actor-Critic (SAC)은 강화 학습 분야에서 강력한 알고리즘으로 자리 잡았습니다. 안정적인 학습, 효율적인 탐색, 그리고 높은 성능을 제공하며, 다양한 환경에서 성공적으로 적용되고 있습니다. SAC의 기본 개념, 작동 원리, 그리고 구현 예시를 통해 이 알고리즘에 대한 이해를 높일 수 있습니다. SAC를 통해 강화 학습 문제를 해결하는 데 한 걸음 더 다가갈 수 있기를 바랍니다.

핵심 용어 요약 (Summary of Key Terms)

  • 강화 학습 (Reinforcement Learning): 에이전트가 환경과 상호 작용하며 보상을 최대화하도록 학습하는 방법 (agent learns to maximize reward by interacting with the environment).
  • Soft Actor-Critic (SAC): 액터-크리틱 기반의 강화 학습 알고리즘으로, 엔트로피를 사용하여 탐색을 향상시킴 (Actor-Critic based reinforcement learning algorithm that uses entropy to improve exploration).
  • 액터 (Actor): 환경에서 행동을 선택하는 정책 (policy that selects actions in the environment).
  • 크리틱 (Critic): 선택된 행동의 가치를 평가하는 함수 (function that evaluates the value of the selected action).
  • 엔트로피 (Entropy): 정책의 무작위성을 측정하는 지표 (a measure of the randomness of the policy).
  • Q-함수 (Q-function): 특정 상태에서 특정 행동을 했을 때 얻을 수 있는 예상 보상을 나타내는 함수 (function that represents the expected reward for taking a specific action in a specific state).
  • 가치 함수 (Value function): 특정 상태에서 얻을 수 있는 예상 총 보상을 나타내는 함수 (function that represents the expected total reward for a specific state).
  • 온도 (Temperature, Alpha): 엔트로피의 가중치를 조절하는 하이퍼파라미터 (hyperparameter that adjusts the weight of entropy).
  • 정책 (Policy): 에이전트가 환경에서 행동을 선택하는 방법을 결정하는 함수 (a function that determines how an agent selects actions in the environment).
  • 확률적 정책 (Stochastic Policy): 각 행동에 대한 확률 분포를 출력하는 정책 (policy that outputs a probability distribution for each action).
반응형