mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): fix model load events on sdxl nodes
they need the `context` to be provided to emit socket events
This commit is contained in:
parent
016797c890
commit
840205496a
@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
|
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
@ -555,7 +555,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
del noise
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
|
Loading…
Reference in New Issue
Block a user