feat(nodes): fix model load events on sdxl nodes

they need the `context` to be provided to emit socket events
This commit is contained in:
psychedelicious 2023-07-26 00:21:09 +10:00
parent 016797c890
commit 840205496a
2 changed files with 11 additions and 9 deletions

View File

@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled): def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(self, context, clip_field, prompt, get_pooled): def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:

View File

@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context
) )
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
@ -555,7 +555,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
del noise del noise
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None