mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(backend): mps should not use non_blocking
We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See: - https://github.com/pytorch/pytorch/issues/107455 - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28 Changes: - Add properties for each device on `TorchDevice` as a convenience. - Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided. - Update model patching and caching APIs to use this new utility. Fixes: #6545
This commit is contained in:
@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
|
||||
new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device, non_blocking=True)
|
||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
|
Reference in New Issue
Block a user