fix(backend): mps should not use non_blocking

We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28

Changes:
- Add properties for each device on `TorchDevice` as a convenience.
- Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided.
- Update model patching and caching APIs to use this new utility.

Fixes: #6545
This commit is contained in:
psychedelicious
2024-06-27 19:15:23 +10:00
parent a0a0c57789
commit c7562dd6c0
4 changed files with 26 additions and 8 deletions

View File

@ -10,6 +10,7 @@ from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.devices import TorchDevice
from .raw_model import RawModel
@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype, non_blocking=True)
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
model.layers[layer_key] = layer
return model