From c7562dd6c0ac9ff0ecf165a0b2e4d35428738307 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 27 Jun 2024 19:15:23 +1000 Subject: [PATCH] fix(backend): mps should not use `non_blocking` We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See: - https://github.com/pytorch/pytorch/issues/107455 - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28 Changes: - Add properties for each device on `TorchDevice` as a convenience. - Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided. - Update model patching and caching APIs to use this new utility. Fixes: #6545 --- invokeai/backend/lora.py | 3 ++- .../load/model_cache/model_cache_default.py | 4 ++-- invokeai/backend/model_patcher.py | 11 ++++++----- invokeai/backend/util/devices.py | 16 ++++++++++++++++ 4 files changed, 26 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py index f7c3863a6a..8d17de0837 100644 --- a/invokeai/backend/lora.py +++ b/invokeai/backend/lora.py @@ -10,6 +10,7 @@ from safetensors.torch import load_file from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType +from invokeai.backend.util.devices import TorchDevice from .raw_model import RawModel @@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module): # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() - layer.to(device=device, dtype=dtype, non_blocking=True) + layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device)) model.layers[layer_key] = layer return model diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index d48e45426e..7331654dc1 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]): else: new_dict: Dict[str, torch.Tensor] = {} for k, v in cache_entry.state_dict.items(): - new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True) + new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)) cache_entry.model.load_state_dict(new_dict, assign=True) - cache_entry.model.to(target_device, non_blocking=True) + cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device)) cache_entry.device = target_device except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index fdc79539ae..993d96784a 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -16,6 +16,7 @@ from invokeai.app.shared.models import FreeUConfig from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.util.devices import TorchDevice from .lora import LoRAModelRaw from .textual_inversion import TextualInversionManager, TextualInversionModelRaw @@ -139,12 +140,12 @@ class ModelPatcher: # We intentionally move to the target device first, then cast. Experimentally, this was found to # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the # same thing in a single call to '.to(...)'. - layer.to(device=device, non_blocking=True) - layer.to(dtype=torch.float32, non_blocking=True) + layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device)) + layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device)) # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) - layer.to(device=torch.device("cpu"), non_blocking=True) + layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE)) assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! if module.weight.shape != layer_weight.shape: @@ -153,7 +154,7 @@ class ModelPatcher: layer_weight = layer_weight.reshape(module.weight.shape) assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! - module.weight += layer_weight.to(dtype=dtype, non_blocking=True) + module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device)) yield # wait for context manager exit @@ -161,7 +162,7 @@ class ModelPatcher: assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() with torch.no_grad(): for module_key, weight in original_weights.items(): - model.get_submodule(module_key).weight.copy_(weight, non_blocking=True) + model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device)) @classmethod @contextmanager diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index e8380dc8bc..1cba70c662 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -42,6 +42,10 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA class TorchDevice: """Abstraction layer for torch devices.""" + CPU_DEVICE = torch.device("cpu") + CUDA_DEVICE = torch.device("cuda") + MPS_DEVICE = torch.device("mps") + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" @@ -108,3 +112,15 @@ class TorchDevice: @classmethod def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype: return NAME_TO_PRECISION[precision_name] + + @staticmethod + def get_non_blocking(to_device: torch.device) -> bool: + """Return the non_blocking flag to be used when moving a tensor to a given device. + MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS. + When moving _from_ MPS, we can use non-blocking operations. + + See: + - https://github.com/pytorch/pytorch/issues/107455 + - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28 + """ + return False if to_device.type == "mps" else True