mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ruff format
This commit is contained in:
parent
c7562dd6c0
commit
14775cc9c4
@ -285,7 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
else:
|
else:
|
||||||
new_dict: Dict[str, torch.Tensor] = {}
|
new_dict: Dict[str, torch.Tensor] = {}
|
||||||
for k, v in cache_entry.state_dict.items():
|
for k, v in cache_entry.state_dict.items():
|
||||||
new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
|
new_dict[k] = v.to(
|
||||||
|
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
|
||||||
|
)
|
||||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
||||||
cache_entry.device = target_device
|
cache_entry.device = target_device
|
||||||
|
@ -145,7 +145,10 @@ class ModelPatcher:
|
|||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||||
layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE))
|
layer.to(
|
||||||
|
device=TorchDevice.CPU_DEVICE,
|
||||||
|
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||||
if module.weight.shape != layer_weight.shape:
|
if module.weight.shape != layer_weight.shape:
|
||||||
@ -162,7 +165,9 @@ class ModelPatcher:
|
|||||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for module_key, weight in original_weights.items():
|
for module_key, weight in original_weights.items():
|
||||||
model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device))
|
model.get_submodule(module_key).weight.copy_(
|
||||||
|
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
Loading…
Reference in New Issue
Block a user