From 430b9c291f91ad90e669486622154003dc147937 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Thu, 13 Jul 2023 22:59:38 +1200
Subject: [PATCH] fix: Loras not working correctly with Inpainting

---
 invokeai/app/invocations/generate.py | 90 ++++++++++++++--------------
 1 file changed, 46 insertions(+), 44 deletions(-)

diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index 8accdb9851..6cdb83effc 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -154,40 +154,42 @@ class InpaintInvocation(BaseInvocation):
 
     @contextmanager
     def load_model_old_way(self, context, scheduler):
+        def _lora_loader():
+            for lora in self.unet.loras:
+                lora_info = context.services.model_manager.get_model(
+                    **lora.dict(exclude={"weight"}))
+                yield (lora_info.context.model, lora.weight)
+                del lora_info
+            return
+
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
         vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
 
-        #unet = unet_info.context.model
-        #vae = vae_info.context.model
+        with vae_info as vae,\
+             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+             unet_info as unet:
 
-        with ExitStack() as stack:
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            device = context.services.model_manager.mgr.cache.execution_device
+            dtype = context.services.model_manager.mgr.cache.precision
 
-            with vae_info as vae,\
-                 unet_info as unet,\
-                 ModelPatcher.apply_lora_unet(unet, loras):
+            pipeline = StableDiffusionGeneratorPipeline(
+                vae=vae,
+                text_encoder=None,
+                tokenizer=None,
+                unet=unet,
+                scheduler=scheduler,
+                safety_checker=None,
+                feature_extractor=None,
+                requires_safety_checker=False,
+                precision="float16" if dtype == torch.float16 else "float32",
+                execution_device=device,
+            )
 
-                device = context.services.model_manager.mgr.cache.execution_device
-                dtype = context.services.model_manager.mgr.cache.precision
-
-                pipeline = StableDiffusionGeneratorPipeline(
-                    vae=vae,
-                    text_encoder=None,
-                    tokenizer=None,
-                    unet=unet,
-                    scheduler=scheduler,
-                    safety_checker=None,
-                    feature_extractor=None,
-                    requires_safety_checker=False,
-                    precision="float16" if dtype == torch.float16 else "float32",
-                    execution_device=device,
-                )
-
-                yield OldModelInfo(
-                    name=self.unet.unet.model_name,
-                    hash="",
-                    model=pipeline,
-                )
+            yield OldModelInfo(
+                name=self.unet.unet.model_name,
+                hash="",
+                model=pipeline,
+            )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = (
@@ -226,21 +228,21 @@ class InpaintInvocation(BaseInvocation):
                 ),  # Shorthand for passing all of the parameters above manually
             )
 
-        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
-        # each time it is called. We only need the first one.
-        generator_output = next(outputs)
+            # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+            # each time it is called. We only need the first one.
+            generator_output = next(outputs)
 
-        image_dto = context.services.images.create(
-            image=generator_output.image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-        )
+            image_dto = context.services.images.create(
+                image=generator_output.image,
+                image_origin=ResourceOrigin.INTERNAL,
+                image_category=ImageCategory.GENERAL,
+                session_id=context.graph_execution_state_id,
+                node_id=self.id,
+                is_intermediate=self.is_intermediate,
+            )
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+            return ImageOutput(
+                image=ImageField(image_name=image_dto.image_name),
+                width=image_dto.width,
+                height=image_dto.height,
+            )
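
Note on the second hunk: as the in-code comment says, outputs is a lazy, infinite iterator, so no sampling happens until next(outputs) runs. Moving that call (and the code that follows it) inside the model context appears to be what keeps the LoRA-patched UNet in place while the image is actually produced. The sketch below is a minimal, self-contained stand-in for that failure mode; Model, apply_lora and generate are hypothetical names, not InvokeAI APIs.

    # Minimal sketch of the lazy-generator pitfall this patch addresses.
    from contextlib import contextmanager

    class Model:
        def __init__(self) -> None:
            self.lora_applied = False

    @contextmanager
    def apply_lora(model):
        model.lora_applied = True       # patch weights on enter
        try:
            yield model
        finally:
            model.lora_applied = False  # restore original weights on exit

    def generate(model):
        # Lazy infinite generator: nothing runs until next() is called.
        while True:
            yield "lora" if model.lora_applied else "base"

    model = Model()

    # Broken ordering: the generator is only advanced after the context
    # has exited, so the patch is already rolled back.
    with apply_lora(model):
        outputs = generate(model)
    assert next(outputs) == "base"      # LoRA no longer applied

    # Fixed ordering: consume the generator inside the with-block,
    # while the patched model is still active.
    with apply_lora(model):
        outputs = generate(model)
        assert next(outputs) == "lora"  # LoRA still applied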