-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_lab_mlp2.m
92 lines (68 loc) · 1.81 KB
/
main_lab_mlp2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
% This is a MATLAB script for the
% CLPS1291 lab on MLPs #2
% Other m-files required: none
% Subfunctions: none
% MAT-files required: ../Data/mnist_all.mat
% Author: Thomas Serre
% Brown University
% CLPS Department
% email: Thomas_Serre@Brown.edu
% Website: http://serre-lab.clps.brown.edu
% March 2014;
% Type 'help nndemos' to get a list of builtin demos
% --- Run settings ------------------------------------------------------
close('all');   % start from a clean slate of figure windows
nTr = 10;       % number of randomly chosen training examples
nTe = 1000;     % number of test examples (disjoint from the training set)
PCA = 0;        % nonzero => replace pixels by principal-component scores
load '../Data/mnist_all.mat'
% Build the full training set: A stacks every digit's images, Y holds the
% matching one-hot label rows.
% NOTE(review): the MAT-file is assumed to define variables train0..train9,
% each holding one image per row (size(...,1) is used as the example count).
A = [];
Y = [];
% load data
for ii = 0:9
    Atmp = eval(['train' num2str(ii)]);   % fetch variable 'train<ii>' by name
    A = cat(1, A, Atmp);
    n = size(Atmp,1);                     % number of images of this digit
    % create labels:
    % (1, 0, ..., 0) for class 1 (digit 0)
    % (0, 1, ..., 0) for class 2 (digit 1)
    % etc
    Ytmp = zeros(1,10);
    Ytmp(ii+1) = 1;
    Y = cat(1, Y, repmat(Ytmp, n, 1));
end
% Transpose so that each COLUMN is one example (pixels along rows) and
% rescale, presumably from 0..255 pixel values to [0, 1].
A = double(A')/255;
% Side length of the (assumed square) images. After the transpose, A has one
% ROW per pixel, so the pixel count is size(A,1); size(A,2) is the number of
% examples and would give a meaningless "side length".
siz = sqrt(size(A,1));
Y = Y';   % labels likewise become one column per example
% Optionally represent each example by its first Ncomp principal-component
% scores instead of raw pixels. A keeps one column per example but ends up
% with Ncomp rows.
if PCA
    Ncomp = 100;
    % Centre every pixel (row of A) on its mean over all examples.
    AVG = mean(A, 2);
    A = bsxfun(@minus, A, AVG);
    % pca works on observations-in-rows data, hence the transposes in and out.
    [PC, score, eigenvalues, tsquared, explained] = pca(A');
    A = transpose(score(:, 1:Ncomp));
end
% Random, disjoint train/test split over the example columns: shuffle the
% column indices once, keep the first nTr for training and the next nTe for
% testing. randperm draws from the global RNG; no seed is set beforehand.
I = randperm(size(A, 2));
Xtr = A(:, I(1:nTr));
Xte = A(:, I(nTr + (1:nTe)));
Ytr = Y(:, I(1:nTr));
Yte = Y(:, I(nTr + (1:nTe)));
%% Train MLP
% h lists the hidden-layer sizes passed to feedforwardnet. It is empty here,
% which presumably yields a network with no hidden layer (confirm).
h = [];
net = feedforwardnet(h); % initialize the network
% Regularization of the performance function. It must be set AFTER
% feedforwardnet: that call returns a brand-new network object, so any
% performParam value assigned to 'net' before it would be discarded.
% NOTE(review): confirm the training function in use (net.trainFcn)
% honours performParam.regularization.
net.performParam.regularization = 0.5;
% % net.layers{1}.transferFcn = 'tansig';
% % net.trainParam.epochs = 50;
% % net.trainParam.goal = 1e-5;
% % net.trainParam.time = 60;
% % net.trainParam.showCommandLine = 1;
% % net.trainParam.show = 1;
net = configure(net, Xtr, Ytr);   % size the network's inputs/outputs to the data
tic
net = train(net, Xtr, Ytr); % train the network
toc                               % report elapsed training time
%% Compute the training and test error
% Note: the perform function does not return the
% classification error
% Note: you can also independently use the plotconfusion function
%% Visualize hidden units
%% Visualize errors
%% Try the Parallel Computing Toolbox