ReplayBuffer.py
import random
from collections import deque

import numpy as np
import torch


class RandomReplayBuffer:
    """Uniform-sampling experience replay buffer for DQN-style training."""

    def __init__(self, buffer_size=10000, batch_size=32, use_conv=True, use_minimax=True):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.use_conv = use_conv
        self.use_minimax = use_minimax
        self.buffer = deque(maxlen=buffer_size)
        # Number of transitions to collect before training starts (warm-up).
        self.start_size = 50000

    def add(self, *exp):
        # exp = (state, action[, opponent_action], reward, next_state, mask, done)
        s, *a, r, s_prime, mask, d = exp
        if self.use_conv:
            # Convolutional input: keep the board's 2-D / multi-channel shape.
            s_ = s
            s_prime_ = s_prime
        else:
            # Fully-connected input: flatten the board to a 1-D vector.
            s_ = s.flatten()
            s_prime_ = s_prime.flatten()
        self.buffer.append((s_, *a, r, s_prime_, mask, d))

    def get_length(self):
        return len(self.buffer)

    def get_maxlen(self):
        return self.buffer.maxlen

    def shuffle(self):
        random.shuffle(self.buffer)

    def sample(self):
        minibatch = random.sample(self.buffer, self.batch_size)
        if self.use_conv:
            # States are stored with the channel dimension already present
            # (input channel=3 test), so stacking yields (batch_size, 3, 6, 7).
            # For single-channel states stored as (6, 7), append .unsqueeze(1)
            # to the stacked tensors to get (batch_size, 1, 6, 7) instead.
            s_batch = torch.stack([s1 for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
            s_prime_batch = torch.stack([s2 for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        else:
            # state_batch.shape: (batch_size, 42)
            s_batch = torch.stack([s1 for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
            s_prime_batch = torch.stack([s2 for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        # action_batch.shape: (batch_size,)
        a_batch = torch.Tensor([a[0] for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        if self.use_minimax:
            # Encode the (own move, opponent reply) pair as a single joint
            # action index in [0, 49): 7 * own_column + opponent_column.
            b_batch = torch.Tensor([a[1] for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
            a_batch = 7 * a_batch + b_batch
        r_batch = torch.Tensor([r for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        m_batch = torch.stack([m for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        d_batch = torch.Tensor([d for (s1, *a, r, s2, m, d) in minibatch]).to(self.device)
        return s_batch, a_batch, r_batch, s_prime_batch, m_batch, d_batch

    def clear(self):
        self.buffer.clear()
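

# Minimal usage sketch (an illustration, not part of the original file). It
# assumes Connect-Four-style 6x7 boards stored as 3-channel float tensors and
# a width-7 legal-move mask; the dummy values below are hypothetical.
if __name__ == "__main__":
    buffer = RandomReplayBuffer(buffer_size=10000, batch_size=32,
                                use_conv=True, use_minimax=True)
    # Fill the buffer with dummy transitions until a batch can be drawn.
    for _ in range(64):
        s = torch.zeros(3, 6, 7)        # current board, 3 input channels
        s_prime = torch.zeros(3, 6, 7)  # board after both players have moved
        a, b = 3, 2                     # own column, opponent's reply column
        r = 0.0                         # reward
        mask = torch.ones(7)            # legal-move mask for s_prime
        d = 0.0                         # done flag
        buffer.add(s, a, b, r, s_prime, mask, d)
    s_b, a_b, r_b, sp_b, m_b, d_b = buffer.sample()
    print(s_b.shape)  # torch.Size([32, 3, 6, 7])
    print(a_b[:4])    # joint action indices: 7*a + b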