mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): fix model load events on sdxl nodes
they need the `context` to be provided to emit socket events
This commit is contained in:
parent
016797c890
commit
840205496a
@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
|
|||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
|
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context
|
||||||
)
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
@ -555,7 +555,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
del noise
|
del noise
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
|
Loading…
Reference in New Issue
Block a user