Commit 969c1593 authored by Eduard Pizur's avatar Eduard Pizur
Browse files

updated train function

parent 974da163
......@@ -99,6 +99,7 @@ class Agent():
q_vals_net = q_vals_net.gather(1, actions.unsqueeze(-1)).squeeze(-1)
q_vals_next_target_net = q_vals_next_target_net.max(1)[0]
q_vals_next_target_net[dones] = 0.0
q_vals_next_target_net = q_vals_next_target_net.detach()
q_target = rewards + DISCOUNT_FACTOR * q_vals_next_target_net
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment