From e03eb7fb45c6a24fee83e30db4d439ce76383ef9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 10 Jun 2024 11:40:46 -0400 Subject: [PATCH] Add support for LoRA models in TiledStableDiffusionRefineInvocation. --- .../invocations/tiled_stable_diffusion_refine.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py index a42d531241..592a00f71c 100644 --- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py +++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py @@ -1,4 +1,5 @@ from contextlib import ExitStack +from typing import Iterator, Tuple import numpy as np import numpy.typing as npt @@ -25,6 +26,8 @@ from invokeai.app.invocations.noise import get_noise from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image +from invokeai.backend.lora import LoRAModelRaw +from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending from invokeai.backend.tiles.utils import Tile @@ -236,11 +239,19 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): # Crop the global noise into tiles. noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles] + # Prepare an iterator that yields the UNet's LoRA models and their weights. + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self.unet.loras: + lora_info = context.models.load(lora.lora) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) + del lora_info + # Load the UNet model. unet_info = context.models.load(self.unet.unet) refined_latent_tiles: list[torch.Tensor] = [] - with ExitStack() as exit_stack, unet_info as unet: + with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()): assert isinstance(unet, UNet2DConditionModel) scheduler = get_scheduler( context=context,