Fix type error in torch device comparison.

This commit is contained in:
Ryan Dick 2023-09-29 14:23:44 -04:00
parent 123f2b2dbc
commit 7d65555a5a

View File

@ -272,7 +272,7 @@ class ModelCache(object):
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if source_device.type == target_device.type:
if torch.device(source_device).type == torch.device(target_device).type:
return
start_model_to_time = time.time()