Skip to content

Commit

Permalink
fix bug when device is cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh6404 committed Jan 24, 2024
1 parent da07ccc commit 0088829
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tracking_ros/node_scripts/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def get_predictor(self):
data_cfg = get_dataset_cfg(cfg)

cutie = CUTIE(cfg).to(self.device).eval()
model_weights = torch.load(cfg.weights)
model_weights = torch.load(cfg.weights, map_location=self.device)
cutie.load_weights(model_weights)

torch.cuda.empty_cache()
if self.device.startswith("cuda"):
torch.cuda.empty_cache()
return InferenceCore(cutie, cfg=cfg)

@classmethod
Expand Down Expand Up @@ -150,7 +151,7 @@ def get_predictor(self):
# Load our checkpoint
deva_model = DEVA(cfg).to(self.device).eval()
if args.model is not None:
model_weights = torch.load(args.model)
model_weights = torch.load(args.model, map_location=self.device)
deva_model.load_weights(model_weights)
else:
print("No model loaded.")
Expand Down

0 comments on commit 0088829

Please sign in to comment.