1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Automatically generated by Colaboratory.
4
+
5
+ Original file is located at
6
+ https://colab.research.google.com/drive/1N1S8ROzSKeUhb1tI_8LKG2r2CWjMz0sT
7
+ """
8
+
9
+ import torch
10
+ import torchvision
11
+ import torchvision .transforms as transforms
12
+
13
# Input pipeline: convert images to tensors, then normalize each of the three
# channels with mean 0.5 and std 0.5 (imshow() undoes this with img / 2 + 0.5).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training and test splits, downloaded into the current directory.
trainset = torchvision.datasets.CIFAR10(
    root='.',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='.',
    train=False,
    download=True,
    transform=transform,
)

# Mini-batches of 4 images; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=8,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=8,
)

# Human-readable class names, indexed by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
21
+
22
+ import matplotlib .pyplot as plt
23
+ import numpy as np
24
+
25
def imshow(img):
    """Display a normalized image tensor in a matplotlib window.

    The pixel values are first mapped back with ``img / 2 + 0.5`` (the inverse
    of normalizing with mean 0.5 and std 0.5), then the first axis is moved to
    the end before plotting.

    Args:
        img: Tensor to show. Assumed to be laid out as (channels, height,
            width), e.g. the output of ``torchvision.utils.make_grid``.
    """
    denormalized = img / 2 + 0.5
    pixels = denormalized.numpy()
    # Reorder axes (0, 1, 2) -> (1, 2, 0) so the channel axis comes last.
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()
30
+
31
# Get one mini-batch of training images.
# Use the built-in next(): the iterator's `.next()` method was a Python 2
# compatibility alias and is not available on newer PyTorch loader iterators.
dataiter = iter(trainloader)
images, labels = next(dataiter)

# Show the images as a grid, followed by their class names.
imshow(torchvision.utils.make_grid(images))
# Iterate over the labels themselves so the batch size is not hard-coded.
print(' '.join('%5s' % classes[label] for label in labels))
38
+
39
+ # Define a CNN
40
+ import torch .nn as nn
41
+ import torch .nn .functional as F
42
+
43
class Net(nn.Module):
    """Small CNN classifier: three conv/ReLU/max-pool stages and three linear layers.

    The first linear layer expects 64 * 3 * 3 features. With the layer sizes
    below (unpadded, stride-1 convolutions and 2x2 pooling) that corresponds to
    3x32x32 inputs: 32 -> 30 -> 15 -> 14 -> 7 -> 6 -> 3 per spatial side.

    Dropout (p=0.5) is applied to the flattened features and after each hidden
    linear layer. It is only active in training mode; call ``eval()`` before
    inference.
    """

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.conv2 = nn.Conv2d(8, 32, 2)
        self.conv3 = nn.Conv2d(32, 64, 2)
        self.fc1 = nn.Linear(64 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 84)
        self.fc3 = nn.Linear(84, 10)
        # Element-wise dropout. The activations it is applied to are already
        # flattened to (batch, features); nn.Dropout2d is meant for 3-D/4-D
        # feature maps and warns on 2-D inputs. Dropout has no parameters, so
        # saved state_dicts stay compatible.
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return the raw class scores (logits) for a batch of images.

        Args:
            x: Input batch; assumed to be of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 10) with unnormalized scores.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample. Keeping the batch dimension fixed makes a wrong
        # input size fail in fc1 instead of silently regrouping the batch.
        x = x.view(x.size(0), -1)
        x = self.drop(x)
        x = F.relu(self.fc1(x))
        x = self.drop(x)
        x = F.relu(self.fc2(x))
        x = self.drop(x)
        x = self.fc3(x)
        return x
67
+
68
net = Net()

# Define the loss function: classification cross-entropy loss
# Define the SGD with momentum as optimizer
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Learning rate to switch to at the start of the given (0-based) epoch.
lr_schedule = {5: 0.0001, 20: 0.00001}

# Train the network
for epoch in range(50):
    # Apply the schedule once per epoch by updating the existing optimizer in
    # place. Building a new optimizer instead would throw away the momentum
    # buffers accumulated so far.
    if epoch in lr_schedule:
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[epoch]

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # Get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # Print every 2000 mini-batches
            print('[%d, %5d] loss: %0.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training.')
103
+
104
# Save the network
## Don't forget to change the name of the PATH!!!
# Checkpoint names used for earlier experiments:
# PATH = './cifar_net.pth'  # default CNN structure
# PATH = './cifar_net_epoch20.pth'  # default CNN structure, epoch=20
# PATH = './cifar_net_dropout_epoch20.pth'  # add dropout, epoch=20
# PATH = './cifar_net_dropout2_epoch20.pth'  # less dropout, epoch=20
# PATH = './cifar_net_dropout_6layer_epoch30.pth'  # 6 layers, epoch=30
# PATH = './cifar_net_dropout_5layer_epoch20.pth'  # 5 layers, epoch=20, feature=32*5*5
# PATH = './cifar_net_changed_rl.pth'  # lr=0.001 when epoch<5, =0.0001 when epoch>=5, no dropout, epoch=20
PATH = './cifar_net_changed_rl_epoch50.pth'  # lr=0.001 (epochs 0-4), 0.0001 (5-19), 0.00001 (20-49), epoch=50
torch.save(net.state_dict(), PATH)

# Reload the network into a fresh model instance
net = Net()
net.load_state_dict(torch.load(PATH))
# A newly constructed module is in training mode; switch to evaluation mode so
# the dropout layers are disabled for everything that follows.
net.eval()
119
+
120
# Test the network on one mini-batch of the test data
net.eval()  # Make sure dropout is disabled for inference
dataiter = iter(testloader)
# Built-in next(): the iterator's `.next()` alias is missing on newer PyTorch.
images, labels = next(dataiter)

# Show the test images and their true classes
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels))

# Predicted results: the class with the highest score for each image
with torch.no_grad():  # No gradients are needed for inference
    outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[guess] for guess in predicted))
132
+
133
# To see the performance of the network on the whole training dataset
net.eval()  # Dropout must be off, otherwise the measured accuracy is noisy and too low
correct = 0
total = 0
with torch.no_grad():
    for images, labels in trainloader:
        outputs = net(images)
        # Index of the highest score along dim 1 is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the %d training images is: %0.2f %%' % (total, 100 * correct / total))
145
+
146
# To see the performance of the network on the whole test dataset
net.eval()  # Dropout must be off, otherwise the measured accuracy is noisy and too low
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        # Index of the highest score along dim 1 is the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the %d test images is: %0.2f %%' % (total, 100 * correct / total))
158
+
159
# Performance of the network on each class
net.eval()  # Dropout must be off while measuring accuracy
class_correct = [0] * len(classes)
class_total = [0] * len(classes)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        hits = predicted == labels
        # Walk over the actual batch contents rather than a fixed count of 4,
        # so a smaller final batch or a different batch size also works.
        for label, hit in zip(labels.tolist(), hits.tolist()):
            class_correct[label] += hit
            class_total[label] += 1

for i in range(len(classes)):
    print('Accuracy of %5s: %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
0 commit comments