Update model cache device comparison to treat 'cuda' and 'cuda:0' as the same device type.

This commit is contained in:
Ryan Dick 2023-09-29 10:34:27 -04:00
parent 1f6699ac43
commit 1e4e42556e

View File

@@ -265,11 +265,13 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def _move_model_to_device(self, key, target_device):
def _move_model_to_device(self, key: str, target_device: torch.device):
cache_entry = self._cached_models[key]
source_device = cache_entry.model.device
if source_device == target_device:
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if source_device.type == target_device.type:
return
start_model_to_time = time.time()