@@ -60,3 +60,34 @@ def test_eval(self):
60
60
# no element in the diff array should be larger than 1e-7
61
61
maxdiff = np .max (np .abs (rtf - rtd ))
62
62
self .assertLess (maxdiff , 1e-7 )
63
+
64
+ def test_ensemble_eval (self ):
65
+ simple_model2 = td .Model ()
66
+ y2 , sess2 = self .get ("simple2" , "y" , "sess" )
67
+ simple_model2 .add (y2 , tf_sess = sess2 )
68
+
69
+ simple_model2 .get ("input_1" ).name = "input:0"
70
+ simple_model2 .get ("output_1" ).name = "output:0"
71
+ simple_model2 .get ("keep_prob_1" ).name = "keep_prob:0"
72
+
73
+ simple_ensemble = td .Ensemble ()
74
+ simple_ensemble .models = [self .simple_model , simple_model2 ]
75
+
76
+ inp , outp , kp = simple_ensemble .get ("input" , "output" , "keep_prob" )
77
+
78
+ # create an input batch
79
+ examples = np .random .rand (1000 , 10 ).astype ("float32" )
80
+
81
+ # eval both models manually and build the mean
82
+ x1 , y1 , keep_prob1 = self .simple_model .get ("input" , "output" , "keep_prob" )
83
+ r1 = y1 .eval ({x1 : examples , keep_prob1 : 1.0 })
84
+ x2 , y2 , keep_prob2 = simple_model2 .get ("input" , "output" , "keep_prob" )
85
+ r2 = y2 .eval ({x2 : examples , keep_prob2 : 1.0 })
86
+ rm = np .add (r1 , r2 ) / 2.
87
+
88
+ # then, eval the ensemble
89
+ re = outp .eval ({inp : examples , kp : 1.0 })
90
+
91
+ # no element in the diff array should be larger than 1e-7
92
+ maxdiff = np .max (np .abs (re - rm ))
93
+ self .assertLess (maxdiff , 1e-7 )
0 commit comments