@@ -124,25 +124,35 @@ def _pull_job_info(self):
124
124
125
125
inference_job = self ._meta_store .get_inference_job (
126
126
worker .inference_job_id )
127
+
127
128
if inference_job is None :
128
129
raise InvalidWorkerError (
129
130
'No such inference job with ID "{}"' .format (
130
131
worker .inference_job_id ))
131
- if inference_job .model_id :
132
- model = self ._meta_store .get_model (inference_job .model_id )
133
- logger .info (f'Using checkpoint of the model "{ model .name } "...' )
134
-
135
- self ._proposal = Proposal .from_jsonable ({
136
- "trial_no" : 1 ,
137
- "knobs" : {}
138
- })
139
- self ._store_params_id = model .checkpoint_id
140
- else :
141
132
142
- trial = self ._meta_store .get_trial (worker .trial_id )
143
- if trial is None or trial .store_params_id is None : # Must have model saved
133
+ trial = self ._meta_store .get_trial (worker .trial_id )
134
+
135
+ # check if there are trained model saved
136
+ if trial is None or trial .store_params_id is None :
137
+
138
+ # if there are no train job, then check if there is checkpoint uplaoded
139
+ if inference_job .model_id :
140
+ model = self ._meta_store .get_model (inference_job .model_id )
141
+ logger .info (f'Using checkpoint of the model "{ model .name } "...' )
142
+
143
+ self ._proposal = Proposal .from_jsonable ({
144
+ "trial_no" : 1 ,
145
+ "knobs" : {}
146
+ })
147
+ self ._store_params_id = model .checkpoint_id
148
+ else :
149
+
150
+ # if there is no checkpoint id and no trained model saved
144
151
raise InvalidTrialError (
145
- 'No saved trial with ID "{}"' .format (worker .trial_id ))
152
+ 'No saved trial with ID "{}" and no checkpoint uploaded' .format (worker .trial_id ))
153
+ else :
154
+
155
+ # create inference with trained parameters first
146
156
logger .info (f'Using trial "{ trial .id } "...' )
147
157
148
158
model = self ._meta_store .get_model (trial .model_id )
@@ -183,6 +193,7 @@ def _predict(self, queries: List[Query]) -> List[Prediction]:
183
193
try :
184
194
predictions = self ._model_inst .predict ([x .query for x in queries ])
185
195
except :
196
+ print ('Error while making predictions:' )
186
197
logger .error ('Error while making predictions:' )
187
198
logger .error (traceback .format_exc ())
188
199
predictions = [None for x in range (len (queries ))]
0 commit comments