Commit 91bebd22 authored by Eduard Pizur's avatar Eduard Pizur

changed replay memory to deque

parent 969c1593
@@ -23,7 +23,7 @@ class Agent():
         self.num_of_actions = num_of_actions
         self.actions = list(range(self.num_of_actions))
-        self.training_loss = None
+        self.training_loss = 0.0

         # init replay memory
         self.memory = ReplayMemory()
@@ -91,8 +91,6 @@ class Agent():
         '''
         states, actions, next_states, rewards, dones = self.extract_batch_of_memory()

-        self.network.optimizer.zero_grad()
-
         q_vals_net = self.network.forward(states)
         q_vals_next_target_net = self.target_network.forward(next_states)
@@ -106,7 +104,8 @@ class Agent():
         # optimize network
         loss = self.network.loss(q_vals_net, q_target)
+        self.network.optimizer.zero_grad()
         loss.backward()
         self.network.optimizer.step()

-        self.training_loss = loss.item()
+        self.training_loss += loss.item()
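
The net effect of these two hunks is to move the gradient-clearing call into the optimization step itself and to accumulate the per-batch loss instead of overwriting it. A minimal runnable sketch of that update order, using a generic linear layer and MSE loss as stand-ins for the repo's Q-network and TD loss:

    import torch

    net = torch.nn.Linear(4, 2)                       # stand-in for the agent's Q-network
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()
    training_loss = 0.0

    x, target = torch.randn(8, 4), torch.randn(8, 2)  # dummy batch
    pred = net(x)                                     # plays the role of q_vals_net
    loss = loss_fn(pred, target)
    opt.zero_grad()                # clear stale gradients before backward()
    loss.backward()
    opt.step()
    training_loss += loss.item()   # accumulate rather than overwrite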
@@ -63,7 +63,8 @@ if __name__ == '__main__':
            learn_steps += 1
            if learn_steps % 2000 == 0:
-               writer.add_scalar('Training loss', agent.training_loss, learn_steps)
+               writer.add_scalar('Training loss', agent.training_loss / 2000, learn_steps)
+               agent.training_loss = 0.0

            state = next_state
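
With the accumulator in place, the scalar written to TensorBoard becomes the mean loss over each 2000-step window rather than a single step's value. A self-contained sketch of the pattern (a constant dummy loss stands in for agent.training_loss):

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter()
    training_loss = 0.0
    for learn_steps in range(1, 6001):
        training_loss += 0.5        # dummy per-step loss
        if learn_steps % 2000 == 0:
            # log the window mean, then reset the accumulator
            writer.add_scalar('Training loss', training_loss / 2000, learn_steps)
            training_loss = 0.0
    writer.close()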
......
 import random
+from collections import deque
 from parameters import *
@@ -8,10 +9,11 @@ class ReplayMemory:
     '''
     def __init__(self):
         self.size = REPLAY_MEMORY_SIZE
-        self.memory = []
         self.batch_size = BATCH_SIZE
         self.index = 0

+        self.memory = deque(maxlen=self.size)
+
     def __len__(self):
         return len(self.memory)
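
Backing the memory with deque(maxlen=...) makes the buffer self-bounding: once it is full, each append evicts the oldest entry automatically, which is exactly the ring-buffer behavior the old list plus self.index implemented by hand. For example:

    from collections import deque

    mem = deque(maxlen=3)
    for i in range(5):
        mem.append(i)
    print(mem)   # deque([2, 3, 4], maxlen=3): the two oldest items were evicted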
@@ -20,9 +22,11 @@ class ReplayMemory:
         appends experience to the memory
         '''
         if len(self.memory) < self.size:
-            self.memory.append(None)
-        self.memory[self.index] = experience
+            self.memory.append(experience)
+            return
+
+        self.memory.insert(self.index, experience)
         self.index = (self.index + 1) % self.size

     def sample(self):
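
One caveat with the new branch: in CPython, insert() on a bounded deque that is already full raises IndexError, so the self.memory.insert(...) path would fail on the first push after the memory fills. With maxlen set, an unconditional append is sufficient and the index arithmetic can be dropped entirely. A quick demonstration:

    from collections import deque

    d = deque(range(3), maxlen=3)   # full: deque([0, 1, 2], maxlen=3)
    try:
        d.insert(1, 99)             # inserting into a full bounded deque
    except IndexError as err:
        print(err)                  # "deque already at its maximum size"

    d.append(99)                    # append still works: the oldest item is dropped
    print(d)                        # deque([1, 2, 99], maxlen=3)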