-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_ring_data.py
59 lines (51 loc) · 1.89 KB
/
run_ring_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from binary_perceptron import BinaryPerceptron # Your implementation of binary perceptron
from plot_bp import PlotBinaryPerceptron
import csv # For loading data.
from matplotlib import pyplot as plt # For creating plots.
import remapper
class PlotRingBP(PlotBinaryPerceptron):
    """
    Plots the Binary Perceptron after training it on the Ring dataset
    ---
    Extends the class PlotBinaryPerceptron
    """

    def __init__(self, bp, plot_all=False, n_epochs=25, IS_REMAPPED=True):
        """
        Store the remapping flag and delegate the rest to the super class
        ---
        bp, plot_all and n_epochs are forwarded unchanged to
        PlotBinaryPerceptron.__init__ (their exact meaning is defined there).
        IS_REMAPPED selects whether read_data() passes both features of
        every sample through remapper.remap() before storing them.
        """
        super().__init__(bp, plot_all, n_epochs)  # Calls the constructor of the super class
        self.IS_REMAPPED = IS_REMAPPED

    def read_data(self):
        """
        Read data from the Ring dataset with 2 features and 2 classes
        for both training and testing.
        ---
        Overrides the method in PlotBinaryPerceptron

        Reads 'ring-data.csv' from the current working directory and stores
        a list of [feature_1, feature_2, class] rows in self.TRAINING_DATA;
        self.TESTING_DATA is a shallow copy of that list.

        Raises OSError if the file cannot be opened and ValueError if a
        non-empty row does not hold exactly three parseable fields.
        """
        # The context manager guarantees the file handle is closed;
        # newline='' is the mode the csv module expects for its input files.
        with open('ring-data.csv', newline='') as csv_file:
            # csv.reader yields [] for blank lines; drop them so a trailing
            # empty line does not break the three-way unpacking below.
            rows = [row for row in csv.reader(csv_file, delimiter=',') if row]

        samples = []
        for f1, f2, c in rows:
            x, y = float(f1), float(f2)
            if self.IS_REMAPPED:
                # Remap once per sample and reuse both components
                # (assumes remapper.remap is deterministic).
                remapped = remapper.remap(x, y)
                x, y = remapped[0], remapped[1]
            samples.append([x, y, int(c)])

        self.TRAINING_DATA = samples
        self.TESTING_DATA = self.TRAINING_DATA.copy()

    def plot(self):
        """
        Decorates the current figure (title, axis labels, legend) and shows it
        ---
        Overrides the method in PlotBinaryPerceptron

        Only the decoration and plt.show() happen here; the data points and
        the classifier are presumably drawn elsewhere (not visible in this
        file).
        """
        # NOTE(review): the title and axis labels below still describe the
        # Iris dataset although this class reads the Ring dataset - confirm
        # the intended wording before changing these user-facing strings.
        plt.title("Iris setosa (blue) vs iris versicolor (red)")
        plt.xlabel("Sepal length")
        plt.ylabel("Petal length")
        plt.legend(loc='best')
        plt.show()
if __name__ == '__main__':
    # Script entry point: wrap a perceptron (learning rate alpha=0.5) in the
    # ring-data plotter, then train, test and display the result.
    # train() and test() are not defined in this file; presumably they are
    # inherited from PlotBinaryPerceptron.
    ring_plotter = PlotRingBP(BinaryPerceptron(alpha=0.5))
    ring_plotter.train()
    ring_plotter.test()
    ring_plotter.plot()