feat(nodes): update invocation context for mm2, update nodes model usage

This commit is contained in:
psychedelicious
2024-02-15 20:43:41 +11:00
parent 88d6de4101
commit 539570cc7a
9 changed files with 141 additions and 147 deletions

View File

@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.load.load_model_by_key(
**self.clip.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(),
context=context,
)
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}), context=context
)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
loaded_model = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(),
context=context,
).model
loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model))
except UnknownModelException:
@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.services.model_manager.load.load_model_by_key(
**clip_field.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**clip_field.text_encoder.model_dump(),
context=context,
)
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
# return zero on empty
if prompt == "" and zero_on_empty:
@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.services.model_manager.load.load_model_by_key(
**lora.model_dump(exclude={"weight"}), context=context
)
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1]
try:
ti_model = context.services.model_manager.load.load_model_by_attr(
model_name=name,
base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion,
context=context,
ti_model = context.models.load_by_attrs(
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
).model
assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model))