-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBatch_norm_predict_EDU.py
75 lines (62 loc) · 3.2 KB
/
Batch_norm_predict_EDU.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torch.nn as nn
# Synthetic test batches are (batch_size, input_size); eps is shared by both layers.
batch_size = 5
input_size = 3
eps = 0.1
class CustomBatchNorm1d:
    """Hand-written counterpart of ``torch.nn.BatchNorm1d``.

    In training mode each call normalizes the input with the batch statistics
    and updates the running statistics. After :meth:`eval` the stored running
    statistics are used instead and are left unchanged.

    Args:
        weight: per-channel scale (gamma); its length fixes the channel count.
        bias: per-channel shift (beta).
        eps: value added to the variance for numerical stability.
        momentum: weight of the *new* batch statistic in the running-average
            update, matching PyTorch:
            ``running = (1 - momentum) * running + momentum * batch_stat``.
    """

    def __init__(self, weight, bias, eps, momentum):
        num_features = weight.shape[0]
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.momentum = momentum
        # False: training mode (batch statistics); True: inference mode.
        self.eva = False

    def __call__(self, input_tensor):
        """Normalize ``input_tensor`` of shape (N, C) or (N, C, L) per channel.

        Raises:
            ValueError: if the input is not 2-D or 3-D, or if training mode
                receives only one value per channel (the unbiased variance
                used for the running estimate would divide by zero).
        """
        ndim = input_tensor.dim()
        if ndim not in (2, 3):
            raise ValueError(f"expected 2D or 3D input (got {ndim}D input)")
        num_features = input_tensor.shape[1]
        # Statistics are per channel (dim 1): reduce over batch (and length).
        reduce_dims = (0, *range(2, ndim))
        # Shape that broadcasts a per-channel vector against the input.
        stat_shape = (1, num_features, *([1] * (ndim - 2)))

        if self.eva:
            mean = self.running_mean
            var = self.running_var
        else:
            count = input_tensor.numel() // num_features
            if count < 2:
                raise ValueError(
                    "Expected more than 1 value per channel when training"
                )
            mean = input_tensor.mean(dim=reduce_dims)
            # The biased (population) variance normalizes the current batch.
            var = ((input_tensor - mean.reshape(stat_shape)) ** 2).mean(
                dim=reduce_dims
            )
            # Running stats are buffers: keep them out of the autograd graph
            # and update them in place. The running variance tracks the
            # unbiased estimate (Bessel's correction), as PyTorch does.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(
                    self.momentum * mean
                )
                self.running_var.mul_(1 - self.momentum).add_(
                    self.momentum * var * count / (count - 1)
                )

        normalized = (input_tensor - mean.reshape(stat_shape)) / torch.sqrt(
            var.reshape(stat_shape) + self.eps
        )
        return normalized * self.weight.reshape(stat_shape) + self.bias.reshape(
            stat_shape
        )

    def train(self, mode=True):
        """Switch to training (``mode=True``) or inference mode; return ``self``."""
        self.eva = not mode
        return self

    def eval(self):
        """Switch to inference mode (use running statistics); return ``self``."""
        return self.train(False)
# Reference layer; momentum is given up front instead of patched in afterwards.
batch_norm = nn.BatchNorm1d(input_size, eps=eps, momentum=0.5)
# Randomize the affine parameters so the comparison is not trivially satisfied
# by the default weight=1 / bias=0. Bias is drawn before weight.
batch_norm.bias.data = torch.randn(input_size, dtype=torch.float)
batch_norm.weight.data = torch.randn(input_size, dtype=torch.float)
# The custom layer shares the reference layer's parameters and hyper-parameters.
custom_batch_norm1d = CustomBatchNorm1d(
    weight=batch_norm.weight.data,
    bias=batch_norm.bias.data,
    eps=eps,
    momentum=batch_norm.momentum,
)
# The check runs automatically when this script executes: the custom layer must
# reproduce nn.BatchNorm1d first in training mode, then in inference mode.


def _run_comparison(num_batches=8):
    """Feed identical random batches to both layers; return True if all outputs matched.

    Every batch is fed to both layers even after a mismatch, so both layers
    see the same sequence of inputs before the mode switch.
    """
    matched = True
    for _ in range(num_batches):
        torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
        norm_output = batch_norm(torch_input)
        custom_output = custom_batch_norm1d(torch_input)
        # Compare shapes first: allclose would broadcast or raise on a mismatch.
        matched &= (
            norm_output.shape == custom_output.shape
            and torch.allclose(norm_output, custom_output, atol=1e-06)
        )
    return matched


all_correct = _run_comparison()
print(all_correct)

batch_norm.eval()
custom_batch_norm1d.eval()
all_correct &= _run_comparison()
print(all_correct)