feat(nodes): update invocation context for mm2, update nodes model usage
@@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
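The pattern above repeats throughout this commit: the verbose `context.services.model_manager.load.load_model_by_key(..., context=context)` call becomes `context.models.load(...)`, with the context no longer threaded through explicitly. The `_lora_loader` generator is also worth noting: it yields one `(model, weight)` pair at a time and drops its reference before the next iteration, so each model can be released as soon as the consumer is done with it. Below is a minimal, self-contained sketch of that generator pattern; `FakeLoRA` and `fake_load` are illustrative stand-ins, not InvokeAI classes.

```python
from typing import Iterator, List, Tuple


class FakeLoRA:
    """Illustrative stand-in for a loaded LoRA model (not InvokeAI's LoRAModelRaw)."""

    def __init__(self, key: str) -> None:
        self.key = key


def fake_load(key: str) -> FakeLoRA:
    # Stands in for context.models.load(**lora.model_dump(exclude={"weight"}))
    return FakeLoRA(key)


def lora_loader(loras: List[Tuple[str, float]]) -> Iterator[Tuple[FakeLoRA, float]]:
    """Yield (model, weight) pairs lazily, releasing each model before loading the next."""
    for key, weight in loras:
        lora_info = fake_load(key)
        yield (lora_info, weight)
        del lora_info  # drop our reference so the model can be freed


for model, weight in lora_loader([("lora-a", 0.75), ("lora-b", 0.5)]):
    print(model.key, weight)
```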
@@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_manager.load.load_model_by_key(
-                    **self.clip.text_encoder.model_dump(),
-                    context=context,
-                ).model
+                loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
                 assert isinstance(loaded_model, TextualInversionModelRaw)
                 ti_list.append((name, loaded_model))
             except UnknownModelException:
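This hunk applies the same key-based load to textual-inversion embeddings discovered in the prompt. Triggers are written as `<name>` tokens, which is why the code strips the first and last characters with `trigger[1:-1]`. Here is a hedged re-implementation of that trigger scan; InvokeAI's actual `extract_ti_triggers_from_prompt` may accept a different character set.

```python
import re
from typing import List


def extract_ti_triggers(prompt: str) -> List[str]:
    # Assumed pattern: textual-inversion triggers appear as <token-name>
    # in the prompt. The project's real regex may differ in detail.
    return re.findall(r"<[\w.\-+/:]+>", prompt)


prompt = "a portrait, style of <my-embedding>, highly detailed"
for trigger in extract_ti_triggers(prompt):
    name = trigger[1:-1]  # strip the surrounding "<" and ">"
    print(name)  # my-embedding
```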
@@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
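The surviving context lines show the `zero_on_empty` short-circuit: an empty prompt can return zeroed conditioning tensors instead of invoking the text encoder at all. A minimal sketch of that behavior follows; the shape `(1, 77, 768)` is an assumed CLIP-like embedding size for illustration, not read from this code.

```python
import torch


def encode_prompt(prompt: str, zero_on_empty: bool) -> torch.Tensor:
    # Sketch of the zero-on-empty short-circuit visible in the hunk above.
    # Shape (1, 77, 768) is an assumed CLIP-like embedding size.
    if prompt == "" and zero_on_empty:
        return torch.zeros(1, 77, 768)
    raise NotImplementedError("real text encoding elided in this sketch")


print(encode_prompt("", zero_on_empty=True).shape)  # torch.Size([1, 77, 768])
```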
@@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 lora_model = lora_info.model
                 assert isinstance(lora_model, LoRAModelRaw)
                 yield (lora_model, lora.weight)
@@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.services.model_manager.load.load_model_by_attr(
-                    model_name=name,
-                    base_model=text_encoder_info.config.base,
-                    model_type=ModelType.TextualInversion,
-                    context=context,
+                ti_model = context.models.load_by_attrs(
+                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
                 ).model
                 assert isinstance(ti_model, TextualInversionModelRaw)
                 ti_list.append((name, ti_model))
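The last hunk swaps `load_model_by_attr` for `context.models.load_by_attrs`, the attribute-based counterpart of the key-based `load` used in the earlier hunks: instead of a unique key, the model is resolved by its (name, base, type) triple. The toy registry below shows the two lookup styles side by side; `ModelRecord` and `ModelRegistry` are hypothetical names for illustration only, not InvokeAI's API.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ModelRecord:
    key: str
    name: str
    base: str
    type: str


class ModelRegistry:
    """Toy registry sketching the two lookup styles in this commit's new API surface."""

    def __init__(self, records: List[ModelRecord]) -> None:
        self._by_key = {r.key: r for r in records}
        self._by_attrs = {(r.name, r.base, r.type): r for r in records}

    def load(self, key: str) -> ModelRecord:
        # Key-based lookup, analogous to context.models.load(...)
        return self._by_key[key]

    def load_by_attrs(self, model_name: str, base_model: str, model_type: str) -> ModelRecord:
        # Attribute-based lookup, analogous to context.models.load_by_attrs(...)
        return self._by_attrs[(model_name, base_model, model_type)]


registry = ModelRegistry([ModelRecord("abc123", "my-embedding", "sd-1", "embedding")])
assert registry.load("abc123") is registry.load_by_attrs("my-embedding", "sd-1", "embedding")
```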