ruff format

Ryan Dick 2024-06-27 09:45:13 -04:00
parent c7562dd6c0
commit 14775cc9c4
2 changed files with 10 additions and 3 deletions

@@ -285,7 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 else:
                     new_dict: Dict[str, torch.Tensor] = {}
                     for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device))
+                        new_dict[k] = v.to(
+                            target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
+                        )
                     cache_entry.model.load_state_dict(new_dict, assign=True)
             cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
             cache_entry.device = target_device
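For orientation, the line wrapped by this hunk sits inside the cache's device-move path: every tensor in the cached state dict is copied to the target device and the copies are re-bound with load_state_dict(assign=True). A minimal, self-contained sketch of that pattern follows; the helper _non_blocking_ok and its CUDA-only rule are assumptions standing in for TorchDevice.get_non_blocking, which is not defined in this diff.

from typing import Dict

import torch


def _non_blocking_ok(device: torch.device) -> bool:
    # Assumption: asynchronous copies are only worthwhile when the destination is a
    # CUDA device. The real TorchDevice.get_non_blocking() may apply other criteria.
    return device.type == "cuda"


def load_cached_state_dict(
    model: torch.nn.Module, state_dict: Dict[str, torch.Tensor], target_device: torch.device
) -> None:
    # Copy each cached tensor to the target device; copy=True leaves the cached copy intact.
    new_dict: Dict[str, torch.Tensor] = {}
    for k, v in state_dict.items():
        new_dict[k] = v.to(target_device, copy=True, non_blocking=_non_blocking_ok(target_device))
    # assign=True (PyTorch >= 2.1) makes the module adopt the new tensors instead of
    # copying them element-wise into its existing parameters.
    model.load_state_dict(new_dict, assign=True)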

@@ -145,7 +145,10 @@ class ModelPatcher:
                         # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                         # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                         layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE, non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE))
+                        layer.to(
+                            device=TorchDevice.CPU_DEVICE,
+                            non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
+                        )
                         assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                         if module.weight.shape != layer_weight.shape:
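The hunk above moves a LoRA layer back to the CPU once its contribution has been computed. A rough sketch of the surrounding patch step, under stated assumptions: the layer object exposes get_weight() and to() as seen in the diff, and the in-place addition at the end is the generic LoRA-patching approach rather than code shown here.

import torch


def patch_module_with_layer(module: torch.nn.Module, layer, lora_weight: float, layer_scale: float) -> None:
    # layer is assumed to be a LoRA layer object with to() and get_weight() methods.
    # Compute the layer's contribution on the module's device, in float32, for speed.
    device = module.weight.device
    layer.to(device=device)
    layer.to(dtype=torch.float32)
    layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
    # Move the LoRA layer back to the CPU so it does not hold GPU memory between uses.
    layer.to(device=torch.device("cpu"))
    # Fold the contribution into the module weight in place (generic LoRA patching step).
    with torch.no_grad():
        module.weight += layer_weight.reshape(module.weight.shape).to(dtype=module.weight.dtype)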
@@ -162,7 +165,9 @@ class ModelPatcher:
         assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
         with torch.no_grad():
             for module_key, weight in original_weights.items():
-                model.get_submodule(module_key).weight.copy_(weight, non_blocking=TorchDevice.get_non_blocking(weight.device))
+                model.get_submodule(module_key).weight.copy_(
+                    weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                )
 
     @classmethod
     @contextmanager
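The last hunk wraps the call that restores the saved original weights when the LoRA patch context exits. A short sketch of that restore step; the CUDA-only non_blocking rule is again an assumption standing in for TorchDevice.get_non_blocking.

from typing import Dict

import torch


def restore_original_weights(model: torch.nn.Module, original_weights: Dict[str, torch.Tensor]) -> None:
    with torch.no_grad():
        for module_key, weight in original_weights.items():
            # Assumption: mirrors the likely intent of TorchDevice.get_non_blocking(weight.device).
            non_blocking = weight.device.type == "cuda"
            # copy_() writes the saved values back into the patched parameter in place.
            model.get_submodule(module_key).weight.copy_(weight, non_blocking=non_blocking)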