diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 7331654dc1..697d3daf9b 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -285,7 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
+                    new_dict[k] = v.to(
+                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
+                    )
                 cache_entry.model.load_state_dict(new_dict, assign=True)
         cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
         cache_entry.device = target_device
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 993d96784a..051d114276 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -145,7 +145,10 @@ class ModelPatcher:
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE))
+                    layer.to(
+                        device=TorchDevice.CPU_DEVICE,
+                        non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
+                    )

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
@@ -162,7 +165,9 @@
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device))
+                    model.get_submodule(module_key).weight.copy_(
+                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                    )

     @classmethod
     @contextmanager
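
Note: every call site in this diff routes the non_blocking flag through TorchDevice.get_non_blocking(device). As a rough sketch of that pattern (the helper below is a hypothetical stand-in, not the actual invokeai implementation): asynchronous copies are generally only worthwhile for transfers to CUDA devices, where they can overlap with compute, so the helper presumably gates the flag on the target device type.

import torch


class TorchDevice:
    """Hypothetical stand-in for invokeai's TorchDevice helper (sketch only)."""

    CPU_DEVICE = torch.device("cpu")

    @classmethod
    def get_non_blocking(cls, to_device: torch.device) -> bool:
        # Assumption: only request an asynchronous copy when targeting CUDA,
        # where non_blocking=True is well-defined and can overlap with compute;
        # elsewhere fall back to a synchronous copy.
        return to_device.type == "cuda"


# Usage mirroring the first hunk: copy a cached state dict to the target device.
target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = {"weight": torch.randn(4, 4)}
new_dict = {
    k: v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
    for k, v in state_dict.items()
}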