-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdenoising.py
116 lines (89 loc) · 3.99 KB
/
denoising.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import numpy as np
import statistics as sts
import extras.parser as parser
import extras.functions as functions
import extras.utils as utils
import perceptron.autoencoder as ae
# Load the run configuration. JSON text is UTF-8, so the encoding is pinned
# instead of depending on the platform's default locale encoding.
with open("config.json", encoding="utf-8") as file:
    config = json.load(file)

# static non changeable vars
error_threshold: float = config["error_threshold"]

# read the input file and build the dataset; the second value returned by
# read_file is not needed by this script
full_dataset, _ = parser.read_file(config["file"], config["system_threshold"])

# activation functions selected by the configured system and beta; they are
# unpacked as the leading arguments of the AutoEncoder constructor below
act_funcs = functions.get_activation_functions(config["system"], config["beta"])

# normalize the data only when the configuration asks for it
if config["normalize"]:
    full_dataset = parser.normalize_data(full_dataset)

# split off the training subset; its size is driven by "training_ratio"
dataset, rest = parser.extract_subset(full_dataset, config["training_ratio"])

# initializes the auto-encoder; the input size is the length of one sample
auto_encoder = ae.AutoEncoder(
    *act_funcs,
    config["mid_layout"],
    len(dataset[0]),
    config["latent_dim"],
    config["momentum"],
    config["alpha"],
)

# randomize w if asked
if config["randomize_w"]:
    auto_encoder.randomize_w(config["randomize_w_ref"], config["randomize_w_by_len"])

# plotting is optional; the plotter is only initialized when it is enabled
plot_bool = bool(config["plot"])
if plot_bool:
    utils.init_plotter()

# noise parameter handed to parser.add_noise / parser.add_noise_dataset
pm: float = config["denoising"]["pm"]
# Train the auto-encoder to map noisy samples back to their clean version.
# An optimizer name other than "None" or "" selects the minimizer-based
# training; otherwise the network is trained epoch by epoch.
if config["optimizer"] not in ("None", ""):
    # randomize the dataset
    dataset = parser.randomize_data(dataset, config["data_random_seed"])

    # train with minimize: noisy samples as input, clean samples as target
    auto_encoder.train_minimizer(
        parser.add_noise_dataset(dataset, pm),
        dataset,
        config["trust"],
        config["use_trust"],
        config["optimizer"],
        config["optimizer_iter"],
        config["optimizer_fev"],
    )

    # plot error vs opt step; guarded like every other plot so nothing is
    # drawn when plotting is disabled and the plotter was never initialized
    if plot_bool:
        utils.plot_values(range(len(auto_encoder.opt_err)), 'opt step',
                          auto_encoder.opt_err, 'error', sci_y=False)
else:
    # vars for plotting
    ep_list = []
    err_list = []

    # train auto-encoder
    for ep in range(config["epochs"]):
        # randomize the dataset everytime
        dataset = parser.randomize_data(dataset, config["data_random_seed"])

        # train for this epoch: noisy sample in, clean sample as target
        for data in dataset:
            auto_encoder.train(parser.add_noise(data, pm), data, config["eta"])

        # apply the changes
        # NOTE(review): the original nesting was lost; update_w() is placed
        # once per epoch here - confirm it is not meant to run per sample.
        auto_encoder.update_w()

        # calculate error against a freshly noised copy of the dataset
        error: float = auto_encoder.error(parser.add_noise_dataset(dataset, pm), dataset,
                                          config["trust"], config["use_trust"])

        # stop early once the configured threshold is reached; the final
        # error is intentionally not recorded, as before
        if error < error_threshold:
            break

        if ep % 50 == 0:
            print(f'Iteration {ep}, error {error}')

        # add error to list
        ep_list.append(ep)
        err_list.append(error)

    # plot error vs epoch
    if plot_bool:
        utils.plot_values(ep_list, 'epoch', err_list, 'error', sci_y=False)
# Evaluate how well the trained auto-encoder removes noise: every sample of
# full_dataset is corrupted PM_ITER times at three noise levels and the mean
# fraction of wrong (rounded) outputs is reported per sample.

# labels for printing (use with full_dataset)
labels: list = ['@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
                'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_']

# number of noisy repetitions averaged for each sample
PM_ITER = 50

# noise levels to compare, all derived from the configured pm
pm_values = [pm / 4, pm, pm * 2.5]
x_superlist = []
err_superlist = []

# Legend entries are built from pm_values so they stay correct for any
# configured pm. The decimal comma keeps the previous label format: with
# pm = 0.25 this yields 'pm=0,0625', 'pm=0,25' and 'pm=0,625'.
leg_list = [f'pm={value:g}'.replace('.', ',') for value in pm_values]

for pm_it in pm_values:
    err_mean: list = []
    for data in full_dataset:
        aux: list = []
        for _ in range(PM_ITER):
            noisy_res = auto_encoder.activation(parser.add_noise(data, pm_it))
            # fraction of mismatching entries after rounding; index 0 is
            # skipped on both sides (presumably a bias entry - confirm)
            aux.append(np.sum(abs(np.around(noisy_res[1:]) - data[1:])) / len(data[1:]))
        letter_err_mean = sts.mean(aux)
        err_mean.append(letter_err_mean)

    x_superlist.append(range(len(full_dataset)))
    err_superlist.append(err_mean)
    print(f'Using pm={pm_it}, error mean is {sts.mean(err_mean)}')

if plot_bool:
    utils.plot_multiple_values(x_superlist, 'Letter', err_superlist, 'Invalid bits', leg_list,
                               sci_y=False, xticks=labels, min_val_y=0, max_val_y=1)
    utils.plot_stackbars(x_superlist, 'Letter', err_superlist, 'Invalid bits', leg_list,
                         sci_y=False, xticks=labels, min_val_y=0, max_val_y=1)

    # hold execution
    # NOTE(review): the original nesting was lost; holding only when plots
    # were drawn is assumed - confirm it should not run unconditionally.
    utils.hold_execution()