ruff format

This commit is contained in:
Ryan Dick
2024-06-27 09:45:13 -04:00
parent c7562dd6c0
commit 14775cc9c4
2 changed files with 10 additions and 3 deletions

View File

@ -285,7 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
new_dict[k] = v.to(
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
cache_entry.device = target_device