diff --git a/hloc/extract_features.py b/hloc/extract_features.py index 4b9a617e..52345f15 100644 --- a/hloc/extract_features.py +++ b/hloc/extract_features.py @@ -247,6 +247,8 @@ def main(conf: Dict, return feature_path device = 'cuda' if torch.cuda.is_available() else 'cpu' + if 'device_id' in conf: + device = conf['device_id'] Model = dynamic_load(extractors, conf['model']['name']) model = Model(conf['model']).eval().to(device) diff --git a/hloc/match_features.py b/hloc/match_features.py index 74e61172..f03cf4ff 100644 --- a/hloc/match_features.py +++ b/hloc/match_features.py @@ -208,6 +208,8 @@ def match_from_paths(conf: Dict, return device = 'cuda' if torch.cuda.is_available() else 'cpu' + if 'device_id' in conf: + device = conf['device_id'] Model = dynamic_load(matchers, conf['model']['name']) model = Model(conf['model']).eval().to(device)