make model manager v2 ready for PR review

- Replace legacy model manager service with the v2 manager.

- Update invocations to use the new load interface (a short
  sketch appears below).

- Fixed many, but not all, of the type-checking errors in the
  invocations. Most were unrelated to the model manager.

- Updated routes. All the new routes live under the route tag
  `model_manager_v2`. To avoid confusion with the old routes,
  they have the URL prefix `/api/v2/models`. The old routes
  have been de-registered. A sketch of the new router
  declaration appears below.

- Added a pytest for the loader (an illustrative outline
  appears below).

- Updated documentation in contributing/MODEL_MANAGER.md
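
The sketch below illustrates the new load interface from the invocation
side; it mirrors the compel.py changes in the diff. The helper name is
hypothetical and the snippet is not runnable on its own (it needs a real
invocation context); only the service path and the `.model` attribute are
taken from this commit.

```python
def load_clip_tokenizer(self, context):
    # Illustrative helper only: the v2 loader lives under
    # context.services.model_manager.load and resolves a model record by key.
    tokenizer_info = context.services.model_manager.load.load_model_by_key(
        **self.clip.tokenizer.model_dump(),
        context=context,
    )
    return tokenizer_info.model  # the raw loaded model object
```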
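For the new routes, the prefix and tag can be read as a FastAPI router
declaration. Only the prefix and the tag are taken from this commit; the
module layout and variable name are assumptions.

```python
from fastapi import APIRouter

# Assumed declaration shape for the v2 model routes.
model_manager_v2_router = APIRouter(
    prefix="/api/v2/models",
    tags=["model_manager_v2"],
)
```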
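A rough outline of the kind of loader test added here. Fixture and
parameter names are placeholders, not the actual fixtures in the test
suite.

```python
# Hypothetical sketch of a loader test; `mm2_loader` and `embedding_model_key`
# are placeholder fixtures, not the real ones.
def test_loader_returns_model(mm2_loader, embedding_model_key):
    loaded = mm2_loader.load_model_by_key(embedding_model_key)
    assert loaded.model is not None  # .model holds the raw model object
```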

Author: Lincoln Stein
Date: 2024-02-10 18:09:45 -05:00
Committed by: psychedelicious
Parent: 2b1dc74080
Commit: 94e8d1b6d5
36 changed files with 680 additions and 435 deletions


@@ -3,6 +3,7 @@ from typing import Iterator, List, Optional, Tuple, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
+from transformers import CLIPTokenizer

 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
@@ -68,18 +69,18 @@ class CompelInvocation(BaseInvocation):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.tokenizer.model_dump(),
             context=context,
         )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
             **self.clip.text_encoder.model_dump(),
             context=context,
         )

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}), context=context
                 )
                 assert isinstance(lora_info.model, LoRAModelRaw)
@@ -93,7 +94,7 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_records.load_model(
+                loaded_model = context.services.model_manager.load.load_model_by_key(
                     **self.clip.text_encoder.model_dump(),
                     context=context,
                 ).model
@@ -164,11 +165,11 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_records.load_model(
+        tokenizer_info = context.services.model_manager.load.load_model_by_key(
             **clip_field.tokenizer.model_dump(),
             context=context,
         )
-        text_encoder_info = context.services.model_records.load_model(
+        text_encoder_info = context.services.model_manager.load.load_model_by_key(
             **clip_field.text_encoder.model_dump(),
             context=context,
         )
@@ -196,7 +197,7 @@ class SDXLPromptInvocationBase:

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_records.load_model(
+                lora_info = context.services.model_manager.load.load_model_by_key(
                     **lora.model_dump(exclude={"weight"}), context=context
                 )
                 lora_model = lora_info.model
@@ -211,7 +212,7 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.services.model_records.load_model_by_attr(
+                ti_model = context.services.model_manager.load.load_model_by_attr(
                     model_name=name,
                     base_model=text_encoder_info.config.base,
                     model_type=ModelType.TextualInversion,
@@ -448,9 +449,9 @@ class ClipSkipInvocation(BaseInvocation):


 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -462,7 +463,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@@ -475,24 +478,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens


-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
+            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)


-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -532,7 +540,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
     log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)


-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '