From 840205496a5f7250df4bbc405091af4a0213cf36 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Wed, 26 Jul 2023 00:21:09 +1000
Subject: [PATCH] feat(nodes): fix model load events on sdxl nodes

they need the `context` to be provided to emit socket events
---
 invokeai/app/invocations/compel.py | 16 +++++++++-------
 invokeai/app/invocations/sdxl.py   |  4 ++--
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index fa1b6939d2..6aadbf509d 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.clip.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
     def run_clip_raw(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.dict(), context=context,
         )
 
         def _lora_loader():
             for lora in clip_field.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
+                        context=context,
                     ).context.model
                 )
             except ModelNotFoundException:
@@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
 
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.dict(), context=context,
         )
 
         def _lora_loader():
             for lora in clip_field.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
+                        context=context,
                     ).context.model
                 )
             except ModelNotFoundException:
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 3a63860053..249e864799 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
 
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context
         )
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
@@ -555,7 +555,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
         del noise
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
         )
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
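
Why the extra keyword argument matters: per the commit message, the model manager can only emit "model load" socket events when it has access to the invocation context that owns the event service, so call sites that omit context load the model silently. Below is a minimal sketch of that pattern under that assumption; EventService, InvocationContext, ModelManager, and the event names are illustrative placeholders, not the actual InvokeAI classes or API.

from contextlib import contextmanager
from typing import Iterator, Optional


class EventService:
    """Placeholder for the socket-event service reached via the invocation context."""

    def emit(self, event: str, payload: dict) -> None:
        # A real service would push this over a socket to the web UI.
        print(f"[socket] {event}: {payload}")


class InvocationContext:
    """Hypothetical stand-in for the per-node invocation context."""

    def __init__(self, events: EventService) -> None:
        self.events = events


class ModelManager:
    @contextmanager
    def get_model(self, model_name: str, context: Optional[InvocationContext] = None) -> Iterator[str]:
        # Without the context there is nowhere to send load events -- the gap the
        # patch closes by threading context= through the SDXL call sites.
        if context is not None:
            context.events.emit("model_load_started", {"model_name": model_name})
        try:
            yield f"<{model_name} weights>"  # pretend this is the loaded model
        finally:
            if context is not None:
                context.events.emit("model_load_completed", {"model_name": model_name})


# Usage mirroring the patched call sites: pass the node's context through.
ctx = InvocationContext(events=EventService())
with ModelManager().get_model("sdxl-base-unet", context=ctx) as unet:
    assert unet  # the UI now sees "started"/"completed" events around the load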