mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Patch LoRA on device when model is already on device.
This commit is contained in:
parent
545c811bf1
commit
379d68f595
@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
|
|||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info as text_encoder,
|
||||||
|
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||||
|
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||||
):
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
|
|||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
|
|
||||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info as text_encoder,
|
||||||
|
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||||
|
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||||
):
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
|
|
||||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||||
unet_info as unet,
|
unet_info as unet,
|
||||||
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
|
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||||
):
|
):
|
||||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
if noise is not None:
|
if noise is not None:
|
||||||
|
@ -112,20 +112,34 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||||
if module_key not in original_weights:
|
|
||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
|
||||||
|
|
||||||
# enable autocast to calc fp16 loras on cpu
|
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||||
# with torch.autocast(device_type="cpu"):
|
# (Performance will be best if this is a CUDA device.)
|
||||||
|
device = module.weight.device
|
||||||
|
dtype = module.weight.dtype
|
||||||
|
|
||||||
|
if module_key not in original_weights:
|
||||||
|
original_weights[module_key] = module.weight.to(device="cpu")
|
||||||
|
|
||||||
|
# We intentionally move to the device first, then cast. Experimentally, this was found to
|
||||||
|
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||||
|
# same thing in a single call to '.to(...)'.
|
||||||
|
tmp_weight = module.weight.to(device=device, copy=True).to(dtype=torch.float32)
|
||||||
|
|
||||||
|
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||||
|
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||||
|
# same thing in a single call to '.to(...)'.
|
||||||
|
layer.to(device=device)
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32)
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
|
layer_weight = layer.get_weight(tmp_weight) * (lora_weight * layer_scale)
|
||||||
|
layer.to(device="cpu")
|
||||||
|
|
||||||
if module.weight.shape != layer_weight.shape:
|
if module.weight.shape != layer_weight.shape:
|
||||||
# TODO: debug on lycoris
|
# TODO: debug on lycoris
|
||||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||||
|
|
||||||
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
module.weight = torch.nn.Parameter((tmp_weight + layer_weight).to(dtype=dtype))
|
||||||
|
|
||||||
yield # wait for context manager exit
|
yield # wait for context manager exit
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user