Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
noblec04 committed Oct 15, 2024
1 parent 946917d commit e8d0f62
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 136 deletions.
67 changes: 67 additions & 0 deletions MatlabGP/+NN/NN.m
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,45 @@

end

function [e,de] = Batchloss(obj,V,x,y,N)

nV = length(V(:));

V = AutoDiff(V(:));

obj = obj.setHPs(V(:));

M = 0;

e = 0;
de = 0;

while size(x,1)>0
M=M+1;
itrain = randsample(size(x,1),min(N,size(x,1)));

xt = x(itrain,:);
yt = y(itrain,:);

x(itrain,:)=[];
y(itrain,:)=[];

[yp] = obj.forward(xt);

[eout] = obj.lossfunc.forward(yt,yp);

e1 = sum(eout,2);

e = e + getvalue(e1);
de1 = getderivs(e1);
de = de + reshape(full(de1),[1 nV]);
end

e = e/M;
de = de/M;

end

function [obj,fval] = train(obj,x,y,lb,ub)%,xv,fv

obj.X = x;
Expand All @@ -137,6 +176,34 @@
func = @(V) obj.loss(V,x,y);


opts = optimoptions('fmincon','SpecifyObjectiveGradient',true,'MaxFunctionEvaluations',3000,'MaxIterations',3000,'Display','final');
[theta,fval] = fmincon(func,tx0,[],[],[],[],[],[],[],opts);

%[theta,fval,xv,fv] = VSGD(func,tx0,'lr',0.01,'gamma',0.01,'iters',1000,'tol',1*10^(-7));

obj = obj.setHPs(theta(:));
end

function [obj,fval] = Batchtrain(obj,x,y,N,lb,ub)%,xv,fv

obj.X = x;
obj.Y = y;

if nargin<5
obj.lb_x = min(x);
obj.ub_x = max(x);
else
obj.lb_x = lb;
obj.ub_x = ub;
end

x = (x - obj.lb_x)./(obj.ub_x - obj.lb_x);

tx0 = (obj.getHPs());

func = @(V) obj.Batchloss(V,x,y,N);


opts = optimoptions('fmincon','SpecifyObjectiveGradient',true,'MaxFunctionEvaluations',3000,'MaxIterations',3000,'Display','final');
[theta,fval] = fmincon(func,tx0,[],[],[],[],[],[],[],opts);

Expand Down
71 changes: 0 additions & 71 deletions MatlabGP/+kernels/ICD.asv

This file was deleted.

17 changes: 17 additions & 0 deletions MatlabGP/examples/TestHetNoiseKernel.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

a = kernels.EQ(1,0.25);

H = diag(log([1 2 1 0.5 0.1 0.6 0.9 2]));

X = lhsdesign(200,2);
x = lhsdesign(8,2);

kxX = a.build(x,X);

K = diag(diag(exp(kxX'*H*kxX)));

figure(1)
clf(1)
plot3(x(:,1),x(:,2),diag(exp((H))),'x','MarkerSize',10,'LineWidth',3)
hold on
plot3(X(:,1),X(:,2),diag(K),'x','MarkerSize',10,'LineWidth',3)
21 changes: 0 additions & 21 deletions MatlabGP/examples/TestICD.asv

This file was deleted.

29 changes: 0 additions & 29 deletions MatlabGP/examples/TestICD_2D.asv

This file was deleted.

40 changes: 40 additions & 0 deletions MatlabGP/examples/testNN_2D_batch.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

clear
close all
clc

nD = 5;
nT = 1000;

xx = -2 + 4*[lhsdesign(nT,nD);utils.HypercubeVerts(nD)];

[yy] = testFuncs.Rosenbrock(xx,1)/7210;

xmesh = -2 + 4*lhsdesign(5000,nD);
ymesh = testFuncs.Rosenbrock(xmesh,1)/7210;

layers{1} = NN.FF(nD,nD);
layers{2} = NN.FF(nD,6);
layers{3} = NN.FF(6,1);

acts{1} = NN.SNAKE(8);
acts{2} = NN.SNAKE(8);

lss = NN.MSE();

nnet = NN.NN(layers,acts,lss);

%%

tic
[nnet2,fval] = nnet.Batchtrain(xx,yy,20);%,xv,fv
toc

%%

yp2 = nnet2.predict(xmesh);

%%

figure
plot(yp2,ymesh,'.')
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,33 @@
close all
clc

xx = [0;lhsdesign(10,1);1];
xx = [0;lhsdesign(1000,1);1];
yy = normrnd(forr(xx,0),0*forr(xx,0)+0.4);

yc(:,1) = double(yy>0);
yc(:,2) = double(yy<=0);


xmesh = linspace(0,1,100)';
ymesh = forr(xmesh,0);

layers{1} = NN2.FF(1,3);
layers{2} = NN2.FF(3,6);
layers{3} = NN2.FF(6,2);
layers{1} = NN.FF(1,3);
layers{2} = NN.FF(3,6);
layers{3} = NN.FF(6,3);

acts{1} = NN2.SNAKE(2);
acts{2} = NN2.SNAKE(1);
acts{1} = NN.SNAKE(2);
acts{2} = NN.SNAKE(1);

lss = NN2.CE();
lss = NN.MAE();

nnet = NN2.NN(layers,acts,lss);
nnet = NN.NN(layers,acts,lss);

%%

tic
[nnet2,fval] = nnet.train(xx,yc);%,xv,fv
[nnet2,fval] = nnet.Batchtrain(xx,yy,10);%,xv,fv
toc

%%

yp2 = nnet2.predict(xmesh);

yp3 = exp(yp2)./sum(exp(yp2));


%%
% figure
Expand Down Expand Up @@ -64,6 +58,9 @@
else
y(i,1) = (6*x(i)-2).^2.*sin(12*x(i)-4)+dx;
end

y(i,2) = 0.4*(6*x(i)-2).^2.*sin(12*x(i)-4)-x(i)-1;
y(i,3) = A*(6*x(i)-2).^2.*sin(12*x(i)-4)+B*(x(i)-0.5)-C;
end

end
Expand Down

0 comments on commit e8d0f62

Please sign in to comment.