-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest.py
42 lines (32 loc) · 1.25 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from tests.test_inference import test_inference, test_orthonormalize
from tests.test_em import test_em
from data_simulator import load_data
from Seq_Data_Class import Param_Class
from core_gpfa.plot_3d import plot_3d, plot_1d, plot_1d_error
import matplotlib.pyplot as plt
if __name__ == "__main__":
    # Manual smoke-test driver: loads one sample .mat file, then runs the
    # inference test, the orthonormalization test and the three plotting
    # helpers in sequence, printing intermediate results along the way.

    # The same file supplies both the sequence data and the parameters.
    # Alternative input used previously: '../fake_data_w_genparams.mat'
    INPUT_FILE = '../em_input_new.mat'
    seq = load_data(INPUT_FILE)
    params = Param_Class()
    params.params_from_mat(INPUT_FILE)

    # Option flags are set after the parameters have been read from file.
    params.learnKernelParams = True
    params.learnGPNoise = False
    params.RforceDiagonal = True

    # EM test is currently disabled; re-enable with:
    # res = test_em(params, seq, kernSDList = 30, minVarFrac=0.01)

    # Inference test: rebinds `seq` to the returned sequences and yields LL
    # (presumably the log-likelihood -- confirm against tests.test_inference).
    seq, LL = test_inference(seq, params)
    print("LL", LL)
    print("xsm", seq[0].xsm)

    # Orthonormalization test: consumes the LL from the inference step and
    # rebinds `seq` again; the third return value is not used here.
    est_params, seq, _ = test_orthonormalize(LL, params, seq)
    print("x_orth", seq[0].x_orth)
    print("C_orth", est_params.C_orth)

    # All three plots read the same sequence attribute and are given the
    # same output path, so both are named once here.
    _plot_field = 'x_orth'
    _plot_output = '../test'
    _binned_plot_kwargs = dict(bin_width=20, output_file=_plot_output)

    # 3d plot over the first three dimensions.
    plot_3d(seq, _plot_field, dims_to_plot=[0, 1, 2], output_file=_plot_output)

    # 1d plot.
    plot_1d(seq, _plot_field, **_binned_plot_kwargs)

    # 1d error plot.
    plot_1d_error(seq, _plot_field, **_binned_plot_kwargs)

    # Uncomment to display the figures interactively:
    # plt.show()