Patch LoRA on device when model is already on device.

This commit is contained in:
Ryan Dick
2023-10-31 15:39:54 -04:00
committed by Kent Keirsey
parent 545c811bf1
commit 379d68f595
3 changed files with 26 additions and 9 deletions

View File

@ -112,20 +112,34 @@ class ModelPatcher:
continue
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
# with torch.autocast(device_type="cpu"):
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.to(device="cpu")
# We intentionally move to the device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
tmp_weight = module.weight.to(device=device, copy=True).to(dtype=torch.float32)
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
layer_weight = layer.get_weight(tmp_weight) * (lora_weight * layer_scale)
layer.to(device="cpu")
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
module.weight = torch.nn.Parameter((tmp_weight + layer_weight).to(dtype=dtype))
yield # wait for context manager exit