BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement the new model loader and modify invocations and embeddings
  to use it (see the sketch after this list).

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move the seamless and silence_warnings utils into a better location.
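
As a minimal sketch of the breaking change (the call shapes and
attribute names follow the diff below; the embedding name is
hypothetical):

    # Before: look a model up by name/base/type; the loaded model hangs
    # off a nested `.context.model` attribute.
    ti_model = context.models.load(
        model_name="easynegative",  # hypothetical embedding name
        base_model=self.clip.text_encoder.base_model,
        model_type=ModelType.TextualInversion,
    ).context.model

    # After: pass the dumped model field (which carries the model key)
    # straight to the record-based loader; the model sits on `.model`.
    text_encoder_info = context.services.model_records.load_model(
        **self.clip.text_encoder.model_dump(),
        context=context,
    )
    text_encoder = text_encoder_info.model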
Author:    Lincoln Stein
Date:      2024-02-05 22:56:32 -05:00
Committer: psychedelicious
Parent:    fbded1c0f2
Commit:    dfcf38be91

31 changed files with 727 additions and 496 deletions

invokeai/app/invocations/compel.py

@@ -1,9 +1,10 @@
-from typing import List, Optional, Union
+from typing import Iterator, List, Optional, Tuple, Union
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 
+import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -12,18 +13,21 @@ from invokeai.app.invocations.fields import (
     UIComponent,
 )
 from invokeai.app.invocations.primitives import ConditioningOutput
+from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.backend.embeddings.lora import LoRAModelRaw
+from invokeai.backend.embeddings.model_patcher import ModelPatcher
+from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.model_manager import ModelType
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import ModelNotFoundException, ModelType
-from ...backend.util.devices import torch_dtype
-from ..util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.backend.util.devices import torch_dtype
 
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -64,13 +68,22 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        tokenizer_info = context.services.model_records.load_model(
+            **self.clip.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_records.load_model(
+            **self.clip.text_encoder.model_dump(),
+            context=context,
+        )
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.services.model_records.load_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
@@ -80,24 +93,20 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.models.load(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                loaded_model = context.services.model_records.load_model(
+                    **self.clip.text_encoder.model_dump(),
+                    context=context,
+                ).model
+                assert isinstance(loaded_model, TextualInversionModelRaw)
+                ti_list.append((name, loaded_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -105,7 +114,7 @@ class CompelInvocation(BaseInvocation):
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -144,6 +153,8 @@ class CompelInvocation(BaseInvocation):
 
+
 class SDXLPromptInvocationBase:
     """Prompt processor for SDXL models."""
+
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -152,20 +163,27 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ):
-        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
+        tokenizer_info = context.services.model_records.load_model(
+            **clip_field.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_records.load_model(
+            **clip_field.text_encoder.model_dump(),
+            context=context,
+        )
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.context.model
+            cpu_text_encoder = text_encoder_info.model
+            assert isinstance(cpu_text_encoder, torch.nn.Module)
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=text_encoder_info.context.cache.precision,
+                dtype=cpu_text_encoder.dtype,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -176,10 +194,14 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.services.model_records.load_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                lora_model = lora_info.model
+                assert isinstance(lora_model, LoRAModelRaw)
+                yield (lora_model, lora.weight)
                 del lora_info
             return
@@ -189,24 +211,24 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.models.load(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                ti_model = context.services.model_records.load_model_by_attr(
+                    model_name=name,
+                    base_model=text_encoder_info.config.base,
+                    model_type=ModelType.TextualInversion,
+                    context=context,
+                ).model
+                assert isinstance(ti_model, TextualInversionModelRaw)
+                ti_list.append((name, ti_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+                logger.warning(f'trigger: "{trigger}" not found')
+            except ValueError:
+                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -214,7 +236,7 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -332,6 +354,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )
 
+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -380,6 +403,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
 
+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
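
For orientation, the patched invocations above all share one shape: a
lazily evaluated LoRA generator feeding a stack of ModelPatcher context
managers. A condensed sketch of that shape (names as in the diff; the
`text_encoder_info as text_encoder` binding falls in a gap between
hunks and is assumed, and the trailing body is elided):

    def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
        # Load each LoRA only when the patcher consumes it, then drop
        # the local reference so the model cache can evict it.
        for lora in self.clip.loras:
            lora_info = context.services.model_records.load_model(
                **lora.model_dump(exclude={"weight"}), context=context
            )
            yield (lora_info.model, lora.weight)
            del lora_info

    with (
        # Textual inversions first: they extend the tokenizer and the
        # text encoder's embedding table.
        ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
            tokenizer,
            ti_manager,
        ),
        text_encoder_info as text_encoder,  # assumed binding, not shown in the hunks
        # LoRA after the encoder is on its target device, for faster patching.
        ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
        # CLIP Skip last, so LoRA application cannot fail on skipped layers.
        ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
    ):
        ...  # build the Compel instance and encode the prompt

Each patcher restores the model's original weights when the `with`
block exits, which is what keeps the cached base models reusable across
invocations.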