diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index b3ebc92320..3a7d5e9e4d 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index a537972c0b..56c13e6816 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index e4f5aeb98e..eb6c50bf0d 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -112,20 +112,34 @@ class ModelPatcher:
                             continue
 
                         module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-                        if module_key not in original_weights:
-                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
-                        # enable autocast to calc fp16 loras on cpu
-                        # with torch.autocast(device_type="cpu"):
+                        # All of the LoRA weight calculations will be done on the same device as the module weight.
+                        # (Performance will be best if this is a CUDA device.)
+                        device = module.weight.device
+                        dtype = module.weight.dtype
+
+                        if module_key not in original_weights:
+                            original_weights[module_key] = module.weight.to(device="cpu")
+
+                        # We intentionally move to the device first, then cast. Experimentally, this was found to
+                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                        # same thing in a single call to '.to(...)'.
+                        tmp_weight = module.weight.to(device=device, copy=True).to(dtype=torch.float32)
+
+                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                        # same thing in a single call to '.to(...)'.
+                        layer.to(device=device)
                         layer.to(dtype=torch.float32)
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                        layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
+                        layer_weight = layer.get_weight(tmp_weight) * (lora_weight * layer_scale)
+                        layer.to(device="cpu")
 
                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
                             layer_weight = layer_weight.reshape(module.weight.shape)
 
-                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
+                        module.weight = torch.nn.Parameter((tmp_weight + layer_weight).to(dtype=dtype))
 
             yield  # wait for context manager exit
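
For reviewers who want to sanity-check the "move to the device first, then cast" claim in the comments above, below is a small stand-alone micro-benchmark sketch. It is not part of the diff; the tensor shape, the `time_transfer` helper, and the iteration counts are illustrative assumptions, and the actual numbers will vary with hardware and transfer bandwidth.

```python
# Illustrative micro-benchmark (not part of the diff): compares a single
# .to(device, dtype) call against .to(device) followed by .to(dtype) for a
# 16-bit CPU tensor headed to a CUDA device, as described in the comments above.
import time

import torch


def time_transfer(fn, warmup: int = 3, iters: int = 10) -> float:
    """Return the average wall-clock seconds per call, synchronizing CUDA around each call."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    if not torch.cuda.is_available():
        raise SystemExit("This sketch needs a CUDA device to be meaningful.")

    # A weight-sized fp16 tensor on the CPU, roughly like a large Linear weight.
    cpu_weight = torch.randn(4096, 4096, dtype=torch.float16)

    single_call = lambda: cpu_weight.to(device="cuda", dtype=torch.float32)
    move_then_cast = lambda: cpu_weight.to(device="cuda").to(dtype=torch.float32)

    print(f"single .to(device, dtype):   {time_transfer(single_call) * 1e3:.2f} ms")
    print(f".to(device) then .to(dtype): {time_transfer(move_then_cast) * 1e3:.2f} ms")
```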