mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(backend): mps should not use non_blocking
We can get black outputs when moving tensors from CPU to MPS. It appears MPS to CPU is fine. See: - https://github.com/pytorch/pytorch/issues/107455 - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28 Changes: - Add properties for each device on `TorchDevice` as a convenience. - Add `get_non_blocking` static method on `TorchDevice`. This utility takes a torch device and returns the flag to be used for non_blocking when moving a tensor to the device provided. - Update model patching and caching APIs to use this new utility. Fixes: #6545
This commit is contained in:
parent
a0a0c57789
commit
c7562dd6c0
@ -10,6 +10,7 @@ from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .raw_model import RawModel
|
||||
|
||||
@ -521,7 +522,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype, non_blocking=True)
|
||||
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
|
||||
new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device, non_blocking=True)
|
||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
|
@ -16,6 +16,7 @@ from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager import AnyModel
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .lora import LoRAModelRaw
|
||||
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||
@ -139,12 +140,12 @@ class ModelPatcher:
|
||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||
# same thing in a single call to '.to(...)'.
|
||||
layer.to(device=device, non_blocking=True)
|
||||
layer.to(dtype=torch.float32, non_blocking=True)
|
||||
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||
layer.to(device=torch.device("cpu"), non_blocking=True)
|
||||
layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE))
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
if module.weight.shape != layer_weight.shape:
|
||||
@ -153,7 +154,7 @@ class ModelPatcher:
|
||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||
|
||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
|
||||
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
|
||||
|
||||
yield # wait for context manager exit
|
||||
|
||||
@ -161,7 +162,7 @@ class ModelPatcher:
|
||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key, weight in original_weights.items():
|
||||
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
|
||||
model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device))
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
|
@ -42,6 +42,10 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
|
||||
class TorchDevice:
|
||||
"""Abstraction layer for torch devices."""
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
|
||||
@classmethod
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
@ -108,3 +112,15 @@ class TorchDevice:
|
||||
@classmethod
|
||||
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||
return NAME_TO_PRECISION[precision_name]
|
||||
|
||||
@staticmethod
|
||||
def get_non_blocking(to_device: torch.device) -> bool:
|
||||
"""Return the non_blocking flag to be used when moving a tensor to a given device.
|
||||
MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
|
||||
When moving _from_ MPS, we can use non-blocking operations.
|
||||
|
||||
See:
|
||||
- https://github.com/pytorch/pytorch/issues/107455
|
||||
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
|
||||
"""
|
||||
return False if to_device.type == "mps" else True
|
||||
|
Loading…
Reference in New Issue
Block a user