# agent.py
import numpy as np
import torch
import torch.nn.functional as F

from network import Policy


class Agent:
    """REINFORCE (Monte Carlo policy gradient) agent."""
    def __init__(
        self, env_name, lr, input_dims, n_actions=4, gamma=0.99, use_cnn=False
    ):
        self.env_name = env_name
        self.lr = lr
        self.input_dims = input_dims
        self.n_actions = n_actions
        self.gamma = gamma
        self.use_cnn = use_cnn
        self.reward_memory = []
        self.action_memory = []
        self.policy = Policy(
            input_dims=self.input_dims,
            n_actions=self.n_actions,
            lr=self.lr,
            chkpt_path=f"weights/{env_name}.pt",
            use_cnn=use_cnn,
        )

    def choose_action(self, state):
        state = torch.Tensor(np.array(state)).to(self.policy.device)
        if self.use_cnn:
            state = state.unsqueeze(0)  # add batch dimension for CNN input
        probs = F.softmax(self.policy(state), dim=-1)
        action_probs = torch.distributions.Categorical(probs)
        action = action_probs.sample()
        log_prob = action_probs.log_prob(action)
        self.action_memory.append(log_prob)  # keep log-prob for the policy-gradient update
        return action.detach().cpu().numpy()

    def learn(self):
        self.policy.optimizer.zero_grad()
        # Monte Carlo return G_t = sum_{k>=t} gamma^(k-t) * r_k, computed with a
        # simple O(n^2) double loop over the episode's rewards
        Gt = np.zeros_like(self.reward_memory, dtype=np.float64)
        for i in range(len(self.reward_memory)):
            G_sum = 0
            discount = 1
            for j in range(i, len(self.reward_memory)):
                G_sum += discount * self.reward_memory[j]
                discount *= self.gamma
            Gt[i] = G_sum
        Gt = torch.Tensor(Gt).to(self.policy.device)
        # REINFORCE loss: negative log-probability of each action weighted by its return
        loss = 0
        for g, logprob in zip(Gt, self.action_memory):
            loss += -g * logprob
        loss.backward()
        self.policy.optimizer.step()
        self.clear_memory()

    def store_rewards(self, reward):
        self.reward_memory.append(reward)

    def clear_memory(self):
        self.reward_memory = []
        self.action_memory = []

    def save_checkpoints(self):
        self.policy.save_checkpoint()

    def load_checkpoints(self):
        self.policy.load_checkpoint()

if __name__ == "__main__":
    # CartPole-v1 has a 4-dimensional observation space and 2 discrete actions
    agent = Agent(
        env_name="CartPole-v1", lr=0.0005, input_dims=(4,), n_actions=2, use_cnn=False
    )
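
    # Minimal training-loop sketch (not part of the original module): it assumes the
    # gymnasium package is available and that Policy in network.py accepts a 1-D
    # observation and exposes the optimizer/device/checkpoint attributes used above.
    # The episode count and logging interval are illustrative only.
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    n_episodes = 500
    for episode in range(n_episodes):
        state, _ = env.reset()
        done = False
        score = 0.0
        while not done:
            action = int(agent.choose_action(state))
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.store_rewards(reward)
            score += reward
        agent.learn()  # one REINFORCE update per completed episode
        if (episode + 1) % 50 == 0:
            print(f"episode {episode + 1}, score {score:.1f}")
    agent.save_checkpoints()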