Commit 156172d1 authored by Eduard Pizur

rewritten dueling deep q network formula

parent 85e36b26
@@ -126,6 +126,7 @@ class Agent():
        # Q(st, at, θ)
        q_vals_net = self.network.forward(states)
        q_vals_net = q_vals_net.gather(1, actions.unsqueeze(-1)).squeeze(-1)
        # Q(st+1, at, θ')
        q_vals_next_target = self.target_network.forward(next_states)
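The comments in the hunk above name the two quantities that feed the one-step TD target: the online network's Q(st, at; θ) for the action actually taken, and the target network's Q(st+1, a; θ'). Assuming the standard DQN form of the update (the file may instead use the double-DQN variant), the target those quantities build is

    y_t = r_{t+1} + γ · max_a Q(s_{t+1}, a; θ')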
@@ -6,6 +6,7 @@ import numpy as np
from utils.constant import *
class DQN(nn.Module):
    '''
    cnn based on the https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
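The DQN class above is the plain Nature-paper CNN; the dueling variant this commit targets replaces its single output head with separate value and advantage streams, which is also where the θ, α, β in the commit's comments come from (shared trunk, advantage-stream, and value-stream parameters in Wang et al.'s notation). A minimal sketch of such a head, assuming the trunk already yields a flat feature vector of width feat_dim and that n_actions is known (DuelingHead, feat_dim and n_actions are illustrative names, not the repo's):

    import torch.nn as nn

    class DuelingHead(nn.Module):
        # Illustrative dueling head: V(s) and A(s, a) streams combined with
        # the mean-subtracted aggregation Q = V + (A - mean over a' of A).
        def __init__(self, feat_dim, n_actions):
            super().__init__()
            self.value = nn.Sequential(
                nn.Linear(feat_dim, 512), nn.ReLU(), nn.Linear(512, 1))
            self.advantage = nn.Sequential(
                nn.Linear(feat_dim, 512), nn.ReLU(), nn.Linear(512, n_actions))

        def forward(self, features):
            v = self.value(features)          # [batch, 1]
            a = self.advantage(features)      # [batch, n_actions]
            # Q(s, a; θ, α, β) = V(s; θ, β) + (A(s, a; θ, α) - mean over a' of A(s, a'; θ, α))
            return v + a - a.mean(dim=1, keepdim=True)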
@@ -2,13 +2,13 @@ import random
import numpy as np
import torch as T
-from q_network import Qnetwork
+from networks.dueling_deep_q_network.q_network import Qnetwork
from utils.replay_memory import ReplayMemory
from utils.constant import *
-class BaseAgent():
+class Agent():
    '''
    Dueling DQN agent for atari games
    '''
@@ -47,7 +47,7 @@ class BaseAgent():
        if random.random() > self.epsilon:
            state_ = T.FloatTensor(state).unsqueeze(0).to(DEVICE)
            actions = self.network.forward(state_)
-            action = T.argmax(actions, dim=1).item()
+            action = T.argmax(actions).item()
        else:
            action = random.choice(self.actions)
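In the action-selection hunk above the network is queried for a single state, so its output has shape [1, n_actions]; a flat argmax and an argmax over dim=1 therefore pick the same action, which is presumably why the dim argument could be dropped. A quick check of that assumption (q_row is an illustrative tensor, not taken from the repo):

    import torch as T

    q_row = T.tensor([[0.1, 0.7, 0.3, 0.2]])   # one state, 4 actions
    assert T.argmax(q_row).item() == T.argmax(q_row, dim=1).item()   # both select action 1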
@@ -89,24 +89,23 @@ class BaseAgent():
    def train(self):
        '''
-        train our agent using DDQN
+        train our agent using dueling DQN
        '''
        states, actions, next_states, rewards, dones = self.extract_batch_of_memory()
-        # predicted values
+        # Q(st, at, θ, α, β)
        q_vals_net = self.network.forward(states)
        q_vals_net = q_vals_net.gather(1, actions.unsqueeze(-1)).squeeze(-1)
-        q_vals_next_net = self.network.forward(next_states)
-        max_next_actions = T.argmax(q_vals_next_net, dim=1)
+        # Q(st+1, at, θ', α, β)
+        q_vals_next_target = self.target_network.forward(next_states)
-        q_vals_next_target_net = self.target_network.forward(next_states)
-        q_vals_next_target_net = q_vals_next_target_net.gather(
-            1, max_next_actions.unsqueeze(-1)).squeeze(-1)
-        q_vals_next_target_net[dones] = 0.0
-        q_vals_next_target_net = q_vals_next_target_net.detach()
+        # max a(Q(st+1, at, θ', α, β))
+        q_vals_next_target = T.max(q_vals_next_target, dim=1, keepdim=True)
+        q_vals_next_target[dones] = 0.0
-        q_target = rewards + DISCOUNT_FACTOR * q_vals_next_target_net
+        # rt+1 + γ*max a(Q(st+1, at, θ', α, β))
+        q_target = rewards + DISCOUNT_FACTOR * q_vals_next_target
        # optimize network
        loss = self.network.loss(q_vals_net, q_target)
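Taken together, the rewritten train() computes the ordinary one-step TD target, only evaluated on the dueling network's Q output. Below is a compact sketch of that computation as a standalone function, under assumed shapes (rewards as a float vector, dones as a boolean mask, networks returning [batch, n_actions] Q-values; the function name and the DISCOUNT_FACTOR value are assumptions, not the repo's). Note that torch's max over a dimension returns a (values, indices) pair, so the values tensor is taken explicitly here.

    import torch as T

    DISCOUNT_FACTOR = 0.99  # assumed value for the repo's constant

    def dueling_dqn_targets(network, target_network,
                            states, actions, next_states, rewards, dones):
        # Q(st, at; θ, α, β): online-network value of the action actually taken
        q_vals_net = network(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        with T.no_grad():
            # max over a of Q(st+1, a; θ', α, β): greedy value under the target network
            q_vals_next_target = target_network(next_states).max(dim=1).values
            q_vals_next_target[dones] = 0.0      # terminal transitions get no bootstrap term
            # rt+1 + γ * max over a of Q(st+1, a; θ', α, β)
            q_target = rewards + DISCOUNT_FACTOR * q_vals_next_target

        return q_vals_net, q_target

If the max is instead taken with keepdim=True, as in the committed lines, the result keeps shape [batch, 1], so rewards would need a matching trailing dimension for the addition to broadcast element-wise.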