From 79d028ecbdd5cbe26765c477f8cbffb27ac225f2 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 5 Feb 2024 22:56:32 -0500
Subject: [PATCH] BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement new model loader and modify invocations and embeddings
- Finish implementation of loaders for all models currently supported by InvokeAI.
- Move lora, textual_inversion, and model patching support into backend/embeddings.
- Restore support for model cache statistics collection (a little ugly, needs work).
- Fixed up invocations that load and patch models.
- Move seamless and silencewarnings utils into a better location
---
 invokeai/app/api/routers/download_queue.py    |   2 +-
 invokeai/app/invocations/compel.py            | 102 +++++-----
 .../controlnet_image_processors.py            |   8 +-
 invokeai/app/invocations/ip_adapter.py        |  52 +++---
 invokeai/app/invocations/latent.py            |  66 +++----
 invokeai/app/invocations/model.py             | 174 +++++-------------
 invokeai/app/invocations/onnx.py              |  97 +++-------
 invokeai/app/invocations/sdxl.py              |  74 ++------
 invokeai/app/invocations/t2i_adapter.py       |   8 +-
 invokeai/app/services/events/events_base.py   |  27 +--
 .../invocation_stats_default.py               |  16 +-
 .../latents_storage/latents_storage_base.py   |   8 +-
 .../latents_storage/latents_storage_disk.py   |   3 +-
 .../latents_storage_forward_cache.py          |   7 +-
 .../model_records/model_records_base.py       |  47 ++++-
 .../model_records/model_records_sql.py        |  92 ++++++++-
 invokeai/backend/embeddings/__init__.py       |   4 +
 invokeai/backend/embeddings/embedding_base.py |  12 ++
 invokeai/backend/embeddings/lora.py           |  14 +-
 invokeai/backend/embeddings/model_patcher.py  | 134 +++-----------
 .../backend/embeddings/textual_inversion.py   | 100 ++++++++++
 invokeai/backend/install/install_helper.py    |   3 +-
 invokeai/backend/model_manager/config.py      |   5 +-
 .../backend/model_manager/load/load_base.py   |   4 +-
 .../model_manager/load/load_default.py        |   4 +-
 .../load/model_cache/__init__.py              |   4 +-
 .../load/model_cache/model_cache_base.py      |  33 +++-
 .../load/model_cache/model_cache_default.py   |  53 +++---
 .../load/model_loaders/textual_inversion.py   |   2 +-
 invokeai/backend/stable_diffusion/__init__.py |   9 +
 invokeai/backend/stable_diffusion/seamless.py | 102 ++++++++++
 invokeai/backend/util/silence_warnings.py     |  28 +++
 invokeai/frontend/install/model_install2.py   |   8 +-
 .../util/test_hf_model_select.py              |   2 +
 tests/test_model_probe.py                     |   6 +-
 35 files changed, 728 insertions(+), 582 deletions(-)
 create mode 100644 invokeai/backend/embeddings/__init__.py
 create mode 100644 invokeai/backend/embeddings/embedding_base.py
 create mode 100644 invokeai/backend/embeddings/textual_inversion.py
 create mode 100644 invokeai/backend/stable_diffusion/seamless.py
 create mode 100644 invokeai/backend/util/silence_warnings.py

diff --git a/invokeai/app/api/routers/download_queue.py b/invokeai/app/api/routers/download_queue.py
index 92b658c370..2dba376c18 100644
--- a/invokeai/app/api/routers/download_queue.py
+++ b/invokeai/app/api/routers/download_queue.py
@@ -55,7 +55,7 @@ async def download(
 ) -> DownloadJob:
     """Download the source URL to the file or directory indicted in dest."""
     queue = ApiDependencies.invoker.services.download_queue
-    return queue.download(source, dest, priority, access_token)
+    return queue.download(source, Path(dest), priority, access_token)
 
 
 @download_queue_router.get(
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 49c62cff56..12dcd9f930 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,22
+1,26 @@ from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Iterator, List, Optional, Tuple, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment +import invokeai.backend.util.logging as logger from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput +from invokeai.app.services.model_records import UnknownModelException from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt +from invokeai.backend.embeddings.lora import LoRAModelRaw +from invokeai.backend.embeddings.model_patcher import ModelPatcher +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw +from invokeai.backend.model_manager import ModelType from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ExtraConditioningInfo, SDXLConditioningInfo, ) +from invokeai.backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import ModelNotFoundException, ModelType -from ...backend.util.devices import torch_dtype -from ..util.ti_utils import extract_ti_triggers_from_prompt from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -66,21 +70,22 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_records.load_model( **self.clip.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_records.load_model( **self.clip.text_encoder.model_dump(), context=context, ) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_records.load_model( **lora.model_dump(exclude={"weight"}), context=context ) - yield (lora_info.context.model, lora.weight) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) del lora_info return @@ -90,25 +95,20 @@ class CompelInvocation(BaseInvocation): for trigger in extract_ti_triggers_from_prompt(self.prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + loaded_model = context.services.model_records.load_model( + **self.clip.text_encoder.model_dump(), + context=context, + ).model + assert isinstance(loaded_model, TextualInversionModelRaw) + ti_list.append((name, loaded_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') with ( - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -116,7 +116,7 @@ class CompelInvocation(BaseInvocation): # Apply the LoRA after text_encoder has been moved to its target device for faster patching. 
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -150,7 +150,7 @@ class CompelInvocation(BaseInvocation): ) conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + context.services.latents.save(conditioning_name, conditioning_data) # TODO: fix type mismatch here return ConditioningOutput( conditioning=ConditioningField( @@ -160,6 +160,8 @@ class CompelInvocation(BaseInvocation): class SDXLPromptInvocationBase: + """Prompt processor for SDXL models.""" + def run_clip_compel( self, context: InvocationContext, @@ -168,26 +170,27 @@ class SDXLPromptInvocationBase: get_pooled: bool, lora_prefix: str, zero_on_empty: bool, - ): - tokenizer_info = context.services.model_manager.get_model( + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: + tokenizer_info = context.services.model_records.load_model( **clip_field.tokenizer.model_dump(), context=context, ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_records.load_model( **clip_field.text_encoder.model_dump(), context=context, ) # return zero on empty if prompt == "" and zero_on_empty: - cpu_text_encoder = text_encoder_info.context.model + cpu_text_encoder = text_encoder_info.model + assert isinstance(cpu_text_encoder, torch.nn.Module) c = torch.zeros( ( 1, cpu_text_encoder.config.max_position_embeddings, cpu_text_encoder.config.hidden_size, ), - dtype=text_encoder_info.context.cache.precision, + dtype=cpu_text_encoder.dtype, ) if get_pooled: c_pooled = torch.zeros( @@ -198,12 +201,14 @@ class SDXLPromptInvocationBase: c_pooled = None return c, c_pooled, None - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_records.load_model( **lora.model_dump(exclude={"weight"}), context=context ) - yield (lora_info.context.model, lora.weight) + lora_model = lora_info.model + assert isinstance(lora_model, LoRAModelRaw) + yield (lora_model, lora.weight) del lora_info return @@ -213,25 +218,24 @@ class SDXLPromptInvocationBase: for trigger in extract_ti_triggers_from_prompt(prompt): name = trigger[1:-1] try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=clip_field.text_encoder.base_model, - model_type=ModelType.TextualInversion, - context=context, - ).context.model, - ) - ) - except ModelNotFoundException: + ti_model = context.services.model_records.load_model_by_attr( + model_name=name, + base_model=text_encoder_info.config.base, + model_type=ModelType.TextualInversion, + context=context, + ).model + assert isinstance(ti_model, TextualInversionModelRaw) + ti_list.append((name, ti_model)) + except UnknownModelException: # print(e) # import traceback # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') + logger.warning(f'trigger: "{trigger}" not found') + except ValueError: + logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models') with ( - 
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as ( tokenizer, ti_manager, ), @@ -239,7 +243,7 @@ class SDXLPromptInvocationBase: # Apply the LoRA after text_encoder has been moved to its target device for faster patching. ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers), + ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers), ): compel = Compel( tokenizer=tokenizer, @@ -357,6 +361,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): dim=1, ) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( @@ -410,6 +415,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) + assert c2_pooled is not None conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 00c3fa74f6..b9c20c7995 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -24,7 +24,7 @@ from controlnet_aux import ( ) from controlnet_aux.util import HWC3, ade_palette from PIL import Image -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -32,7 +32,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.shared.fields import FieldDescriptions from invokeai.backend.image_util.depth_anything import DepthAnythingDetector -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -57,10 +56,7 @@ CONTROLNET_RESIZE_VALUES = Literal[ class ControlNetModelField(BaseModel): """ControlNet model field""" - model_name: str = Field(description="Name of the ControlNet model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model config record key for the ControlNet model") class ControlField(BaseModel): diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 6bd2889624..9404728824 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -2,7 +2,8 @@ import os from builtins import float from typing import List, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator +from typing_extensions import Self from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -17,22 +18,16 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from 
invokeai.app.shared.fields import FieldDescriptions -from invokeai.backend.model_management.models.base import BaseModelType, ModelType -from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id +from invokeai.backend.model_manager import BaseModelType, ModelType +# LS: Consider moving these two classes into model.py class IPAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the IP-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the IP-Adapter model") class CLIPVisionModelField(BaseModel): - model_name: str = Field(description="Name of the CLIP Vision image encoder model") - base_model: BaseModelType = Field(description="Base model (usually 'Any')") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Key to the CLIP Vision image encoder model") class IPAdapterField(BaseModel): @@ -49,16 +44,26 @@ class IPAdapterField(BaseModel): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self +def get_ip_adapter_image_encoder_model_id(model_path: str): + """Read the ID of the image encoder associated with the IP-Adapter at `model_path`.""" + image_encoder_config_file = os.path.join(model_path, "image_encoder.txt") + + with open(image_encoder_config_file, "r") as f: + image_encoder_model = f.readline().strip() + + return image_encoder_model + + @invocation_output("ip_adapter_output") class IPAdapterOutput(BaseInvocationOutput): # Outputs @@ -87,33 +92,36 @@ class IPAdapterInvocation(BaseInvocation): @field_validator("weight") @classmethod - def validate_ip_adapter_weight(cls, v): + def validate_ip_adapter_weight(cls, v: float) -> float: validate_weights(v) return v @model_validator(mode="after") - def validate_begin_end_step_percent(self): + def validate_begin_end_step_percent(self) -> Self: validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. - ip_adapter_info = context.services.model_manager.model_info( - self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter - ) + ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model # directly, and 2) we are reading from disk every time this invocation is called without caching the result. # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. + # TODO (LS): Fix the issue above by: + # 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model. + # 2. Update probe.py to read `image_encoder.txt` and store it in the config. + # 3. Change below to get the image encoder from the configuration record. 
image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) + os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() - image_encoder_model = CLIPVisionModelField( - model_name=image_encoder_model_name, - base_model=BaseModelType.Any, + image_encoder_models = context.services.model_records.search_by_attr( + model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision ) + assert len(image_encoder_models) == 1 + image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b77363ceb8..a621f9fe71 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,13 +3,13 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import Iterator, List, Literal, Optional, Tuple, Union import einops import numpy as np import torch import torchvision.transforms as T -from diffusers import AutoencoderKL, AutoencoderTiny +from diffusers import AutoencoderKL, AutoencoderTiny, UNet2DConditionModel from diffusers.image_processor import VaeImageProcessor from diffusers.models.adapter import T2IAdapter from diffusers.models.attention_processor import ( @@ -38,14 +38,13 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend.embeddings.model_patcher import ModelPatcher from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus -from invokeai.backend.model_management.models import ModelType, SilenceWarnings +from invokeai.backend.model_manager import AnyModel, BaseModelType +from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo +from invokeai.backend.util.silence_warnings import SilenceWarnings -from ...backend.model_management.lora import ModelPatcher -from ...backend.model_management.models import BaseModelType -from ...backend.model_management.seamless import set_seamless -from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ControlNetData, IPAdapterData, @@ -77,7 +76,7 @@ if choose_torch_device() == torch.device("mps"): DEFAULT_PRECISION = choose_precision(choose_torch_device()) -SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] +SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # FIXME: "Invalid type alias" # HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to # be addressed if future models use a different latent scale factor. 
Also, note that there may be places where the scale @@ -156,7 +155,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ) if image is not None: - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_records.load_model( **self.vae.vae.model_dump(), context=context, ) @@ -189,7 +188,7 @@ def get_scheduler( seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( + orig_scheduler_info = context.services.model_records.load_model( **scheduler_info.model_dump(), context=context, ) @@ -422,10 +421,8 @@ class DenoiseLatentsInvocation(BaseInvocation): controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=control_info.control_model.model_name, - model_type=ModelType.ControlNet, - base_model=control_info.control_model.base_model, + context.services.model_records.load_model( + key=control_info.control_model.key, context=context, ) ) @@ -490,18 +487,14 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( - model_name=single_ip_adapter.ip_adapter_model.model_name, - model_type=ModelType.IPAdapter, - base_model=single_ip_adapter.ip_adapter_model.base_model, + context.services.model_records.load_model( + key=single_ip_adapter.ip_adapter_model.key, context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( - model_name=single_ip_adapter.image_encoder_model.model_name, - model_type=ModelType.CLIPVision, - base_model=single_ip_adapter.image_encoder_model.base_model, + image_encoder_model_info = context.services.model_records.load_model( + key=single_ip_adapter.image_encoder_model.key, context=context, ) @@ -554,10 +547,8 @@ class DenoiseLatentsInvocation(BaseInvocation): t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( - model_name=t2i_adapter_field.t2i_adapter_model.model_name, - model_type=ModelType.T2IAdapter, - base_model=t2i_adapter_field.t2i_adapter_model.base_model, + t2i_adapter_model_info = context.services.model_records.load_model( + key=t2i_adapter_field.t2i_adapter_model.key, context=context, ) image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) @@ -593,7 +584,7 @@ class DenoiseLatentsInvocation(BaseInvocation): do_classifier_free_guidance=False, width=t2i_input_width, height=t2i_input_height, - num_channels=t2i_adapter_model.config.in_channels, + num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict device=t2i_adapter_model.device, dtype=t2i_adapter_model.dtype, resize_mode=t2i_adapter_field.resize_mode, @@ -703,28 +694,30 @@ class DenoiseLatentsInvocation(BaseInvocation): def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) - def _lora_loader(): + def _lora_loader() -> Iterator[Tuple[AnyModel, float]]: for lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( + lora_info = context.services.model_records.load_model( **lora.model_dump(exclude={"weight"}), context=context, ) - yield (lora_info.context.model, lora.weight) + yield (lora_info.model, 
lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( + unet_info = context.services.model_records.load_model( **self.unet.unet.model_dump(), context=context, ) + assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, - ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), - set_seamless(unet_info.context.model, self.unet.seamless_axes), + ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config), + set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME unet_info as unet, # Apply the LoRA after unet has been moved to its target device for faster patching. ModelPatcher.apply_lora_unet(unet, _lora_loader()), ): + assert isinstance(unet, torch.Tensor) latents = latents.to(device=unet.device, dtype=unet.dtype) if noise is not None: noise = noise.to(device=unet.device, dtype=unet.dtype) @@ -822,12 +815,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_records.load_model( **self.vae.vae.model_dump(), context=context, ) - with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: + with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: + assert isinstance(vae, torch.Tensor) latents = latents.to(vae.device) if self.fp32: vae.to(dtype=torch.float32) @@ -1063,7 +1057,7 @@ class ImageToLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.services.images.get_pil_image(self.image.image_name) - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_records.load_model( **self.vae.vae.model_dump(), context=context, ) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 99dcc72999..e0e61ea26c 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,12 +1,12 @@ import copy from typing import List, Optional -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, Field from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.shared.models import FreeUConfig -from ...backend.model_management import BaseModelType, ModelType, SubModelType +from ...backend.model_manager import SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -20,13 +20,9 @@ from .baseinvocation import ( class ModelInfo(BaseModel): - model_name: str = Field(description="Info to load submodel") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Info to load submodel") + key: str = Field(description="Info to load submodel") submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") - model_config = ConfigDict(protected_namespaces=()) - class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") @@ -55,7 +51,7 @@ class VaeField(BaseModel): @invocation_output("unet_output") class UNetOutput(BaseInvocationOutput): - """Base class for invocations that output a UNet field""" + """Base class for invocations that output a UNet field.""" unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") @@ -84,20 +80,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, 
VAEOutput): class MainModelField(BaseModel): """Main model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model key") class LoRAModelField(BaseModel): """LoRA model field""" - model_name: str = Field(description="Name of the LoRA model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="LoRA model key") @invocation( @@ -114,74 +103,31 @@ class MainModelLoaderInvocation(BaseInvocation): # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! 
Check if model corrupted" - ) - """ + if not context.services.model_records.exists(key): + raise Exception(f"Unknown model {key}") return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -189,9 +135,7 @@ class MainModelLoaderInvocation(BaseInvocation): ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=key, submodel=SubModelType.Vae, ), ), @@ -229,21 +173,16 @@ class LoraLoaderInvocation(BaseInvocation): if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unkown lora name: {lora_name}!") + if not context.services.model_records.exists(lora_key): + raise Exception(f"Unkown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() @@ -251,9 +190,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -263,9 +200,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -318,24 +253,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation): if self.lora is None: raise Exception("No LoRA provided") - base_model = self.lora.base_model - lora_name = self.lora.model_name + lora_key = self.lora.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, - ): - raise Exception(f"Unknown lora name: {lora_name}!") + if not context.services.model_records.exists(lora_key): + raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): - raise Exception(f'Lora "{lora_name}" already applied to unet') + if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is 
not None and any(lora.model_name == lora_name for lora in self.clip.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip') + if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): - raise Exception(f'Lora "{lora_name}" already applied to clip2') + if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() @@ -343,9 +273,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -355,9 +283,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -367,9 +293,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): output.clip2 = copy.deepcopy(self.clip2) output.clip2.loras.append( LoraInfo( - base_model=base_model, - model_name=lora_name, - model_type=ModelType.Lora, + key=lora_key, submodel=None, weight=self.weight, ) @@ -381,10 +305,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation): class VAEModelField(BaseModel): """Vae model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model's key") @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") @@ -398,25 +319,12 @@ class VaeLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> VAEOutput: - base_model = self.vae_model.base_model - model_name = self.vae_model.model_name - model_type = ModelType.Vae + key = self.vae_model.key - if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, - ): - raise Exception(f"Unkown vae name: {model_name}!") - return VAEOutput( - vae=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ) - ) - ) + if not context.services.model_records.exists(key): + raise Exception(f"Unkown vae: {key}!") + + return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) @invocation_output("seamless_output") diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 759cfde700..5d39a3d7e7 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -8,16 +8,16 @@ from typing import List, Literal, Union import numpy as np import torch from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, Field, field_validator from tqdm import tqdm from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.shared.fields import FieldDescriptions from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import 
BaseModelType, ModelType, SubModelType +from invokeai.backend import ModelType, SubModelType +from invokeai.backend.embeddings.model_patcher import ONNXModelPatcher -from ...backend.model_management import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util import choose_torch_device from ..util.ti_utils import extract_ti_triggers_from_prompt @@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( + tokenizer_info = context.services.model_records.load_model( **self.clip.tokenizer.model_dump(), ) - text_encoder_info = context.services.model_manager.get_model( + text_encoder_info = context.services.model_records.load_model( **self.clip.text_encoder.model_dump(), ) with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: loras = [ ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, + context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.clip.loras @@ -84,11 +84,11 @@ class ONNXPromptInvocation(BaseInvocation): ti_list.append( ( name, - context.services.model_manager.get_model( + context.services.model_records.load_model_by_attr( model_name=name, - base_model=self.clip.text_encoder.base_model, + base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion, - ).context.model, + ).model, ) ) except Exception: @@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation): eta=0.0, ) - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) + unet_info = context.services.model_records.load_model(**self.unet.unet.model_dump()) with unet_info as unet: # , ExitStack() as stack: # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] loras = [ ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, + context.services.model_records.load_model(**lora.model_dump(exclude={"weight"})).model, lora.weight, ) for lora in self.unet.loras @@ -344,9 +344,9 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") + raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}") - vae_info = context.services.model_manager.get_model( + vae_info = context.services.model_records.load_model( **self.vae.vae.model_dump(), ) @@ -400,11 +400,7 @@ class ONNXModelLoaderOutput(BaseInvocationOutput): class OnnxModelField(BaseModel): """Onnx model field""" - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model ID") @invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") @@ -416,74 +412,31 @@ class OnnxModelLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = 
self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return ONNXModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -491,17 +444,13 @@ class OnnxModelLoaderInvocation(BaseInvocation): ), vae_decoder=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.VaeDecoder, ), ), vae_encoder=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.VaeEncoder, ), ), diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 68076fdfeb..4cb5efbbb6 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,6 +1,6 @@ from invokeai.app.shared.fields import FieldDescriptions +from invokeai.backend.model_manager import SubModelType -from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -44,45 +44,31 @@ class SDXLModelLoaderInvocation(BaseInvocation): # TODO: precision? 
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder, ), loras=[], @@ -90,15 +76,11 @@ class SDXLModelLoaderInvocation(BaseInvocation): ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder2, ), loras=[], @@ -106,9 +88,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Vae, ), ), @@ -133,45 +113,31 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): # TODO: precision? 
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.Main + model_key = self.model.key # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + if not context.services.model_records.exists(model_key): + raise Exception(f"Unknown model: {model_key}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.TextEncoder2, ), loras=[], @@ -179,9 +145,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ), vae=VaeField( vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, + key=model_key, submodel=SubModelType.Vae, ), ), diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index e055d23903..09819672b7 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -16,14 +16,10 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESI from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.shared.fields import FieldDescriptions -from invokeai.backend.model_management.models.base import BaseModelType class T2IAdapterModelField(BaseModel): - model_name: str = Field(description="Name of the T2I-Adapter model") - base_model: BaseModelType = Field(description="Base model") - - model_config = ConfigDict(protected_namespaces=()) + key: str = Field(description="Model record key for the T2I-Adapter model") class T2IAdapterField(BaseModel): diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f3349..af6fe4923f 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -11,8 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueStatus, ) from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_management.model_manager import ModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: @@ -171,10 +170,7 @@ class EventServiceBase: queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, + 
model_config: AnyModelConfig, ) -> None: """Emitted when a model is requested""" self.__emit_queue_event( @@ -184,10 +180,7 @@ class EventServiceBase: "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, + "model_config": model_config.model_dump(), }, ) @@ -197,11 +190,7 @@ class EventServiceBase: queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: SubModelType, - model_info: ModelInfo, + model_config: AnyModelConfig, ) -> None: """Emitted when a model is correctly loaded (returns model info)""" self.__emit_queue_event( @@ -211,13 +200,7 @@ class EventServiceBase: "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "model_name": model_name, - "base_model": base_model, - "model_type": model_type, - "submodel": submodel, - "hash": model_info.hash, - "location": str(model_info.location), - "precision": str(model_info.precision), + "model_config": model_config.model_dump(), }, ) diff --git a/invokeai/app/services/invocation_stats/invocation_stats_default.py b/invokeai/app/services/invocation_stats/invocation_stats_default.py index 501a4c04e5..8883ebe295 100644 --- a/invokeai/app/services/invocation_stats/invocation_stats_default.py +++ b/invokeai/app/services/invocation_stats/invocation_stats_default.py @@ -2,6 +2,7 @@ import json import time from contextlib import contextmanager from pathlib import Path +from typing import Iterator import psutil import torch @@ -10,7 +11,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.services.invoker import Invoker from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError -from invokeai.backend.model_management.model_cache import CacheStats +from invokeai.backend.model_manager.load.model_cache import CacheStats from .invocation_stats_base import InvocationStatsServiceBase from .invocation_stats_common import ( @@ -41,7 +42,10 @@ class InvocationStatsService(InvocationStatsServiceBase): self._invoker = invoker @contextmanager - def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str): + def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]: + services = self._invoker.services + if services.model_records is None or services.model_records.loader is None: + yield None if not self._stats.get(graph_execution_state_id): # First time we're seeing this graph_execution_state_id. self._stats[graph_execution_state_id] = GraphExecutionStats() @@ -55,8 +59,10 @@ class InvocationStatsService(InvocationStatsServiceBase): start_ram = psutil.Process().memory_info().rss if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() - if self._invoker.services.model_manager: - self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id]) + + # TO DO [LS]: clean up loader service - shouldn't be an attribute of model records + assert services.model_records.loader is not None + services.model_records.loader.ram_cache.stats = self._cache_stats[graph_execution_state_id] try: # Let the invocation run. 
@@ -73,7 +79,7 @@ class InvocationStatsService(InvocationStatsServiceBase): ) self._stats[graph_execution_state_id].add_node_execution_stats(node_stats) - def _prune_stale_stats(self): + def _prune_stale_stats(self) -> None: """Check all graphs being tracked and prune any that have completed/errored. This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so diff --git a/invokeai/app/services/latents_storage/latents_storage_base.py b/invokeai/app/services/latents_storage/latents_storage_base.py index 9fa42b0ae6..95a0e3e748 100644 --- a/invokeai/app/services/latents_storage/latents_storage_base.py +++ b/invokeai/app/services/latents_storage/latents_storage_base.py @@ -1,10 +1,12 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Union import torch +from ..compel import ConditioningFieldData + class LatentsStorageBase(ABC): """Responsible for storing and retrieving latents.""" @@ -20,8 +22,10 @@ class LatentsStorageBase(ABC): def get(self, name: str) -> torch.Tensor: pass + # (LS) Added a Union with ConditioningFieldData to fix type mismatch errors in compel.py + # Not 100% sure this isn't an existing bug. @abstractmethod - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: pass @abstractmethod diff --git a/invokeai/app/services/latents_storage/latents_storage_disk.py b/invokeai/app/services/latents_storage/latents_storage_disk.py index 9192b9147f..ba6dbd3a28 100644 --- a/invokeai/app/services/latents_storage/latents_storage_disk.py +++ b/invokeai/app/services/latents_storage/latents_storage_disk.py @@ -7,6 +7,7 @@ import torch from invokeai.app.services.invoker import Invoker +from ..compel import ConditioningFieldData from .latents_storage_base import LatentsStorageBase @@ -27,7 +28,7 @@ class DiskLatentsStorage(LatentsStorageBase): latent_path = self.get_path(name) return torch.load(latent_path) - def save(self, name: str, data: torch.Tensor) -> None: + def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: self.__output_folder.mkdir(parents=True, exist_ok=True) latent_path = self.get_path(name) torch.save(data, latent_path) diff --git a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py index 6232b76a27..1edda736a4 100644 --- a/invokeai/app/services/latents_storage/latents_storage_forward_cache.py +++ b/invokeai/app/services/latents_storage/latents_storage_forward_cache.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) from queue import Queue -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch from invokeai.app.services.invoker import Invoker +from ..compel import ConditioningFieldData from .latents_storage_base import LatentsStorageBase @@ -46,7 +47,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase): self.__set_cache(name, latent) return latent - def save(self, name: str, data: torch.Tensor) -> None: + # TODO: (LS) ConditioningFieldData added as Union because of type-checking errors + # in compel.py. Unclear whether this is a long-standing bug, but seems to run. 
+ def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None: self.__underlying_storage.save(name, data) self.__set_cache(name, data) self._on_changed(data) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 42e3c8f83a..e00dd4169d 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -10,6 +10,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union from pydantic import BaseModel, Field +from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager import ( AnyModelConfig, @@ -19,6 +20,7 @@ from invokeai.backend.model_manager import ( ModelType, SubModelType, ) +from invokeai.backend.model_manager.load import AnyModelLoader from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore @@ -110,12 +112,45 @@ class ModelRecordServiceBase(ABC): pass @abstractmethod - def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + def load_model( + self, + key: str, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: """ Load the indicated model into memory and return a LoadedModel object. :param key: Key of model config to be fetched. - :param submodel_type: For main (pipeline models), the submodel to fetch + :param submodel: For main (pipeline models), the submodel to fetch + :param context: Invocation context, used for event issuing. + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + """ + pass + + @abstractmethod + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + This is provided for API compatibility with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that was returned. + + :param model_name: Name of the model config to be fetched. + :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. Exceptions: UnknownModelException -- model with this key not known NotImplementedException -- a model loader was not provided at initialization time @@ -166,7 +201,7 @@ class ModelRecordServiceBase(ABC): @abstractmethod def exists(self, key: str) -> bool: """ - Return True if a model with the indicated key exists in the databse. + Return True if a model with the indicated key exists in the database.
:param key: Unique key for the model to be deleted """ @@ -209,6 +244,12 @@ class ModelRecordServiceBase(ABC): """ pass + @property + @abstractmethod + def loader(self) -> Optional[AnyModelLoader]: + """Return the model loader used by this instance.""" + pass + def all_models(self) -> List[AnyModelConfig]: """Return all the model configs in the database.""" return self.search_by_attr() diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index b50cd17a75..28a77b1b1a 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -46,6 +46,8 @@ from math import ceil from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple, Union +from invokeai.app.invocations.baseinvocation import InvocationContext +from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.backend.model_manager.config import ( AnyModelConfig, @@ -88,6 +90,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """Return the underlying database.""" return self._db + @property + def loader(self) -> Optional[AnyModelLoader]: + """Return the model loader used by this instance.""" + return self._loader + def add_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig: """ Add a model to the database. @@ -213,20 +220,73 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model - def load_model(self, key: str, submodel_type: Optional[SubModelType]) -> LoadedModel: + def load_model( + self, + key: str, + submodel: Optional[SubModelType], + context: Optional[InvocationContext] = None, + ) -> LoadedModel: """ Load the indicated model into memory and return a LoadedModel object. :param key: Key of model config to be fetched. - :param submodel_type: For main (pipeline models), the submodel to fetch. + :param submodel: For main (pipeline models), the submodel to fetch. + :param context: Invocation context used for event reporting Exceptions: UnknownModelException -- model with this key not known NotImplementedException -- a model loader was not provided at initialization time """ if not self._loader: raise NotImplementedError(f"Class {self.__class__} was not initialized with a model loader") + # we can emit model loading events if we are executing with access to the invocation context + model_config = self.get_model(key) - return self._loader.load_model(model_config, submodel_type) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + ) + loaded_model = self._loader.load_model(model_config, submodel) + if context: + self._emit_load_event( + context=context, + model_config=model_config, + loaded=True, + ) + return loaded_model + + def load_model_by_attr( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + submodel: Optional[SubModelType] = None, + context: Optional[InvocationContext] = None, + ) -> LoadedModel: + """ + Load the indicated model into memory and return a LoadedModel object. + + This is provided for API compatibility with the get_model() method + in the original model manager. However, note that LoadedModel is + not the same as the original ModelInfo that was returned. + + :param model_name: Name of the model config to be fetched.
+ :param base_model: Base model + :param model_type: Type of the model + :param submodel: For main (pipeline models), the submodel to fetch + :param context: The invocation context. + + Exceptions: UnknownModelException -- model with this key not known + NotImplementedException -- a model loader was not provided at initialization time + ValueError -- more than one model matches this combination + """ + configs = self.search_by_attr(model_name, base_model, model_type) + if len(configs) == 0: + raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model") + elif len(configs) > 1: + raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") + else: + return self.load_model(configs[0].key, submodel) def exists(self, key: str) -> bool: """ @@ -416,3 +476,29 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return PaginatedResults( page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items ) + + def _emit_load_event( + self, + context: InvocationContext, + model_config: AnyModelConfig, + loaded: Optional[bool] = False, + ) -> None: + if context.services.queue.is_canceled(context.graph_execution_state_id): + raise CanceledException() + + if not loaded: + context.services.events.emit_model_load_started( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) + else: + context.services.events.emit_model_load_completed( + queue_id=context.queue_id, + queue_item_id=context.queue_item_id, + queue_batch_id=context.queue_batch_id, + graph_execution_state_id=context.graph_execution_state_id, + model_config=model_config, + ) diff --git a/invokeai/backend/embeddings/__init__.py b/invokeai/backend/embeddings/__init__.py new file mode 100644 index 0000000000..46ead533c4 --- /dev/null +++ b/invokeai/backend/embeddings/__init__.py @@ -0,0 +1,4 @@ +"""Initialization file for invokeai.backend.embeddings modules.""" + +# from .model_patcher import ModelPatcher +# __all__ = ["ModelPatcher"] diff --git a/invokeai/backend/embeddings/embedding_base.py b/invokeai/backend/embeddings/embedding_base.py new file mode 100644 index 0000000000..5e752a29e1 --- /dev/null +++ b/invokeai/backend/embeddings/embedding_base.py @@ -0,0 +1,12 @@ +"""Base class for LoRA and Textual Inversion models. + +The EmbeddingModelRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw, +and is used for type checking of calls to the model patcher. + +The use of "Raw" here is a historical artifact, and carried forward in +order to avoid confusion.
+""" + + +class EmbeddingModelRaw: + """Base class for LoRA and Textual Inversion models.""" diff --git a/invokeai/backend/embeddings/lora.py b/invokeai/backend/embeddings/lora.py index 9a59a97708..3c7ef074ef 100644 --- a/invokeai/backend/embeddings/lora.py +++ b/invokeai/backend/embeddings/lora.py @@ -11,6 +11,8 @@ from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType +from .embedding_base import EmbeddingModelRaw + class LoRALayerBase: # rank: Optional[int] @@ -317,7 +319,7 @@ class FullLayer(LoRALayerBase): self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - ): + ) -> None: super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) @@ -367,7 +369,7 @@ AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix -class LoRAModelRaw: # (torch.nn.Module): +class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] @@ -471,16 +473,16 @@ class LoRAModelRaw: # (torch.nn.Module): file_path = Path(file_path) model = cls( - name=file_path.stem, # TODO: + name=file_path.stem, layers={}, ) if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + sd = load_file(file_path.absolute().as_posix(), device="cpu") else: - state_dict = torch.load(file_path, map_location="cpu") + sd = torch.load(file_path, map_location="cpu") - state_dict = cls._group_state(state_dict) + state_dict = cls._group_state(sd) if base_model == BaseModelType.StableDiffusionXL: state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py index 6d73235197..4725181b8e 100644 --- a/invokeai/backend/embeddings/model_patcher.py +++ b/invokeai/backend/embeddings/model_patcher.py @@ -4,22 +4,20 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from pathlib import Path -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple import numpy as np import torch -from compel.embeddings_provider import BaseTextualInversionManager -from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel -from safetensors.torch import load_file +from diffusers import OnnxRuntimeModel, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer -from typing_extensions import Self from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel from .lora import LoRAModelRaw +from .textual_inversion import TextualInversionManager, TextualInversionModelRaw """ loras = [ @@ -67,7 +65,7 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -76,8 +74,8 @@ class ModelPatcher: def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModelRaw, float]], - ): + loras: Iterator[Tuple[LoRAModelRaw, float]], + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -87,7 +85,7 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, 
loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te1_"): yield @@ -97,7 +95,7 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModelRaw, float]], - ): + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te2_"): yield @@ -105,10 +103,10 @@ class ModelPatcher: @contextmanager def apply_lora( cls, - model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel], - loras: List[Tuple[LoRAModelRaw, float]], + model: AnyModel, + loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: original_weights = {} try: with torch.no_grad(): @@ -125,6 +123,7 @@ class ModelPatcher: # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the # intricacies of Stable Diffusion key resolution. It should just expect the input LoRA # weights to have valid keys. + assert isinstance(model, torch.nn.Module) module_key, module = cls._resolve_lora_key(model, layer_key, prefix) # All of the LoRA weight calculations will be done on the same device as the module weight. @@ -170,8 +169,8 @@ class ModelPatcher: cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, - ti_list: List[Tuple[str, TextualInversionModel]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ti_list: List[Tuple[str, TextualInversionModelRaw]], + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: init_tokens_count = None new_tokens_added = None @@ -201,7 +200,7 @@ class ModelPatcher: trigger += f"-!pad-{i}" return f"<{trigger}>" - def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor: + def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor: # for SDXL models, select the embedding that matches the text encoder's dimensions if ti.embedding_2 is not None: return ( @@ -229,6 +228,7 @@ class ModelPatcher: model_embeddings = text_encoder.get_input_embeddings() for ti_name, ti in ti_list: + assert isinstance(ti, TextualInversionModelRaw) ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti) ti_tokens = [] @@ -267,7 +267,7 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, clip_skip: int, - ) -> Generator[None, None, None]: + ) -> None: skipped_layers = [] try: for _i in range(clip_skip): @@ -285,7 +285,7 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> Generator[None, None, None]: + ) -> None: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? 
@@ -301,94 +301,6 @@ class ModelPatcher: unet.disable_freeu() -class TextualInversionModel: - embedding: torch.Tensor # [n, 768]|[n, 1280] - embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> Self: - if not isinstance(file_path, Path): - file_path = Path(file_path) - - result = cls() # TODO: - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - # both v1 and v2 format embeddings - # difference mostly in metadata - if "string_to_param" in state_dict: - if len(state_dict["string_to_param"]) > 1: - print( - f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first', - " token will be used.", - ) - - result.embedding = next(iter(state_dict["string_to_param"].values())) - - # v3 (easynegative) - elif "emb_params" in state_dict: - result.embedding = state_dict["emb_params"] - - # v5(sdxl safetensors file) - elif "clip_g" in state_dict and "clip_l" in state_dict: - result.embedding = state_dict["clip_g"] - result.embedding_2 = state_dict["clip_l"] - - # v4(diffusers bin files) - else: - result.embedding = next(iter(state_dict.values())) - - if len(result.embedding.shape) == 1: - result.embedding = result.embedding.unsqueeze(0) - - if not isinstance(result.embedding, torch.Tensor): - raise ValueError(f"Invalid embeddings file: {file_path.name}") - - return result - - -# no type hints for BaseTextualInversionManager? -class TextualInversionManager(BaseTextualInversionManager): # type: ignore - pad_tokens: Dict[int, List[int]] - tokenizer: CLIPTokenizer - - def __init__(self, tokenizer: CLIPTokenizer): - self.pad_tokens = {} - self.tokenizer = tokenizer - - def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: - if len(self.pad_tokens) == 0: - return token_ids - - if token_ids[0] == self.tokenizer.bos_token_id: - raise ValueError("token_ids must not start with bos_token_id") - if token_ids[-1] == self.tokenizer.eos_token_id: - raise ValueError("token_ids must not end with eos_token_id") - - new_token_ids = [] - for token_id in token_ids: - new_token_ids.append(token_id) - if token_id in self.pad_tokens: - new_token_ids.extend(self.pad_tokens[token_id]) - - # Do not exceed the max model input size - # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(), - # which first removes and then adds back the start and end tokens. 
- max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 - if len(new_token_ids) > max_length: - new_token_ids = new_token_ids[0:max_length] - - return new_token_ids - - class ONNXModelPatcher: @classmethod @contextmanager @@ -396,7 +308,7 @@ class ONNXModelPatcher: cls, unet: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(unet, loras, "lora_unet_"): yield @@ -406,7 +318,7 @@ class ONNXModelPatcher: cls, text_encoder: OnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], - ) -> Generator[None, None, None]: + ) -> None: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @@ -419,7 +331,7 @@ class ONNXModelPatcher: model: IAIOnnxRuntimeModel, loras: List[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> Generator[None, None, None]: + ) -> None: from .models.base import IAIOnnxRuntimeModel if not isinstance(model, IAIOnnxRuntimeModel): @@ -506,7 +418,7 @@ class ONNXModelPatcher: tokenizer: CLIPTokenizer, text_encoder: IAIOnnxRuntimeModel, ti_list: List[Tuple[str, Any]], - ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]: + ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]: from .models.base import IAIOnnxRuntimeModel if not isinstance(text_encoder, IAIOnnxRuntimeModel): diff --git a/invokeai/backend/embeddings/textual_inversion.py b/invokeai/backend/embeddings/textual_inversion.py new file mode 100644 index 0000000000..389edff039 --- /dev/null +++ b/invokeai/backend/embeddings/textual_inversion.py @@ -0,0 +1,100 @@ +"""Textual Inversion wrapper class.""" + +from pathlib import Path +from typing import Dict, List, Optional, Union + +import torch +from compel.embeddings_provider import BaseTextualInversionManager +from safetensors.torch import load_file +from transformers import CLIPTokenizer +from typing_extensions import Self + +from .embedding_base import EmbeddingModelRaw + + +class TextualInversionModelRaw(EmbeddingModelRaw): + embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not isinstance(file_path, Path): + file_path = Path(file_path) + + result = cls() # TODO: + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + # both v1 and v2 format embeddings + # difference mostly in metadata + if "string_to_param" in state_dict: + if len(state_dict["string_to_param"]) > 1: + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. 
The first', " token will be used.", ) + + result.embedding = next(iter(state_dict["string_to_param"].values())) + + # v3 (easynegative) + elif "emb_params" in state_dict: + result.embedding = state_dict["emb_params"] + + # v5(sdxl safetensors file) + elif "clip_g" in state_dict and "clip_l" in state_dict: + result.embedding = state_dict["clip_g"] + result.embedding_2 = state_dict["clip_l"] + + # v4(diffusers bin files) + else: + result.embedding = next(iter(state_dict.values())) + + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + + if not isinstance(result.embedding, torch.Tensor): + raise ValueError(f"Invalid embeddings file: {file_path.name}") + + return result + + +# no type hints for BaseTextualInversionManager? +class TextualInversionManager(BaseTextualInversionManager): # type: ignore + pad_tokens: Dict[int, List[int]] + tokenizer: CLIPTokenizer + + def __init__(self, tokenizer: CLIPTokenizer): + self.pad_tokens = {} + self.tokenizer = tokenizer + + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: + if len(self.pad_tokens) == 0: + return token_ids + + if token_ids[0] == self.tokenizer.bos_token_id: + raise ValueError("token_ids must not start with bos_token_id") + if token_ids[-1] == self.tokenizer.eos_token_id: + raise ValueError("token_ids must not end with eos_token_id") + + new_token_ids = [] + for token_id in token_ids: + new_token_ids.append(token_id) + if token_id in self.pad_tokens: + new_token_ids.extend(self.pad_tokens[token_id]) + + # Do not exceed the max model input size + # The -2 here is compensating for compel.embeddings_provider.get_token_ids(), + # which first removes and then adds back the start and end tokens. + max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2 + if len(new_token_ids) > max_length: + new_token_ids = new_token_ids[0:max_length] + + return new_token_ids diff --git a/invokeai/backend/install/install_helper.py b/invokeai/backend/install/install_helper.py index 57dfadcaea..8877e33092 100644 --- a/invokeai/backend/install/install_helper.py +++ b/invokeai/backend/install/install_helper.py @@ -241,10 +241,11 @@ class InstallHelper(object): if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url): repo_id = match.group(1) repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None + subfolder = Path(model_info.subfolder) if model_info.subfolder else None return HFModelSource( repo_id=repo_id, access_token=HfFolder.get_token(), - subfolder=model_info.subfolder, + subfolder=subfolder, variant=repo_variant, ) if re.match(r"^(http|https):", model_path_id_or_url): diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 49ce6af2b8..0dcd925c84 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -30,8 +30,11 @@ from typing_extensions import Annotated, Any, Dict from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from ..embeddings.embedding_base import EmbeddingModelRaw from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus +AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw] + class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" @@ -299,7 +302,7 @@ AnyModelConfig = Union[ ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
-AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus] + # IMPLEMENTATION NOTE: # The preferred alternative to the above is a discriminated Union as shown diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index ee9d6d53e3..9d98ee3053 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -18,8 +18,8 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple, Type from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.util.logging import InvokeAILogger diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 757745072d..2192c88ac2 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -19,7 +19,7 @@ from invokeai.backend.model_manager import ( ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase -from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase +from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import choose_torch_device, torch_dtype @@ -71,7 +71,7 @@ class ModelLoader(ModelLoaderBase): model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type) if not model_path.exists(): - raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}") + raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}") model_path = self._convert_if_needed(model_config, model_path, submodel_type) locker = self._load_if_needed(model_config, model_path, submodel_type) diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py index 0cb5184f3a..32c682d042 100644 --- a/invokeai/backend/model_manager/load/model_cache/__init__.py +++ b/invokeai/backend/model_manager/load/model_cache/__init__.py @@ -1,4 +1,6 @@ """Init file for ModelCache.""" +from .model_cache_base import ModelCacheBase, CacheStats # noqa: F401 +from .model_cache_default import ModelCache # noqa: F401 -_all__ = ["ModelCacheBase", "ModelCache"] +__all__ = ["ModelCacheBase", "ModelCache", "CacheStats"] diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index b1a6768ee8..4a4a3c7d29 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -8,13 +8,13 @@ model will be cleared and (re)loaded from disk when next needed. """ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from logging import Logger -from typing import Generic, Optional, TypeVar +from typing import Dict, Generic, Optional, TypeVar import torch -from invokeai.backend.model_manager import AnyModel, SubModelType +from invokeai.backend.model_manager.config import AnyModel, SubModelType class ModelLockerBase(ABC): @@ -65,6 +65,19 @@ class CacheRecord(Generic[T]): return self._locks > 0 +@dataclass +class CacheStats(object): + """Collect statistics on cache performance.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + class ModelCacheBase(ABC, Generic[T]): """Virtual base class for RAM model cache.""" @@ -98,10 +111,22 @@ class ModelCacheBase(ABC, Generic[T]): pass @abstractmethod - def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None: + def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None: """Move model into the indicated device.""" pass + @property + @abstractmethod + def stats(self) -> Optional[CacheStats]: + """Return collected CacheStats object.""" + pass + + @stats.setter + @abstractmethod + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collecting cache statistics.""" + pass + @property @abstractmethod def logger(self) -> Logger: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 7e30512a58..b1deb215b2 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -24,19 +24,17 @@ import math import sys import time from contextlib import suppress -from dataclasses import dataclass, field from logging import Logger from typing import Dict, List, Optional import torch -from invokeai.backend.model_manager import SubModelType -from invokeai.backend.model_manager.load.load_base import AnyModel +from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.logging import InvokeAILogger -from .model_cache_base import CacheRecord, ModelCacheBase +from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase from .model_locker import ModelLocker, ModelLockerBase if choose_torch_device() == torch.device("mps"): @@ -56,20 +54,6 @@ GIG = 1073741824 MB = 2**20 -@dataclass -class CacheStats(object): - """Collect statistics on cache performance.""" - - hits: int = 0 # cache hits - misses: int = 0 # cache misses - high_watermark: int = 0 # amount of cache used - in_cache: int = 0 # number of models in cache - cleared: int = 0 # number of models cleared to make space -
cache_size: int = 0 # total size of cache - # {submodel_key => size} - loaded_model_sizes: Dict[str, int] = field(default_factory=dict) - - class ModelCache(ModelCacheBase[AnyModel]): """Implementation of ModelCacheBase.""" @@ -110,7 +94,7 @@ class ModelCache(ModelCacheBase[AnyModel]): self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG # used for stats collection - self.stats = CacheStats() + self._stats: Optional[CacheStats] = None self._cached_models: Dict[str, CacheRecord[AnyModel]] = {} self._cache_stack: List[str] = [] @@ -140,6 +124,16 @@ class ModelCache(ModelCacheBase[AnyModel]): """Return the cap on cache size.""" return self._max_cache_size + @property + def stats(self) -> Optional[CacheStats]: + """Return collected CacheStats object.""" + return self._stats + + @stats.setter + def stats(self, stats: CacheStats) -> None: + """Set the CacheStats object for collecting cache statistics.""" + self._stats = stats + def cache_size(self) -> int: """Get the total size of the models currently cached.""" total = 0 @@ -189,21 +183,24 @@ class ModelCache(ModelCacheBase[AnyModel]): """ key = self._make_cache_key(key, submodel_type) if key in self._cached_models: - self.stats.hits += 1 + if self.stats: + self.stats.hits += 1 else: - self.stats.misses += 1 + if self.stats: + self.stats.misses += 1 raise IndexError(f"The model with key {key} is not in the cache.") cache_entry = self._cached_models[key] # more stats - stats_name = stats_name or key - self.stats.cache_size = int(self._max_cache_size * GIG) - self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) - self.stats.in_cache = len(self._cached_models) - self.stats.loaded_model_sizes[stats_name] = max( - self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size - ) + if self.stats: + stats_name = stats_name or key + self.stats.cache_size = int(self._max_cache_size * GIG) + self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size()) + self.stats.in_cache = len(self._cached_models) + self.stats.loaded_model_sizes[stats_name] = max( + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + ) # this moves the entry to the top (right end) of the stack with suppress(Exception): diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py index 394fddc75d..6635f6b43f 100644 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Tuple -from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw +from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw from invokeai.backend.model_manager import ( AnyModel, AnyModelConfig, diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py index 212045f81b..75e6aa0a5d 100644 --- a/invokeai/backend/stable_diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/__init__.py @@ -4,3 +4,12 @@ Initialization file for the invokeai.backend.stable_diffusion package from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401 from .diffusion import InvokeAIDiffuserComponent # noqa: F401 from .diffusion.cross_attention_map_saving
import AttentionMapSaver # noqa: F401 +from .seamless import set_seamless # noqa: F401 + +__all__ = [ + "PipelineIntermediateState", + "StableDiffusionGeneratorPipeline", + "InvokeAIDiffuserComponent", + "AttentionMapSaver", + "set_seamless", +] diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py new file mode 100644 index 0000000000..bfdf9e0c53 --- /dev/null +++ b/invokeai/backend/stable_diffusion/seamless.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import List, Union + +import torch.nn as nn +from diffusers.models import AutoencoderKL, UNet2DConditionModel + + +def _conv_forward_asymmetric(self, input, weight, bias): + """ + Patch for Conv2d._conv_forward that supports asymmetric padding + """ + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) + return nn.functional.conv2d( + working, + weight, + bias, + self.stride, + nn.modules.utils._pair(0), + self.dilation, + self.groups, + ) + + +@contextmanager +def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]): + try: + to_restore = [] + + for m_name, m in model.named_modules(): + if isinstance(model, UNet2DConditionModel): + if ".attentions." in m_name: + continue + + if ".resnets." in m_name: + if ".conv2" in m_name: + continue + if ".conv_shortcut" in m_name: + continue + + """ + if isinstance(model, UNet2DConditionModel): + if False and ".upsamplers." in m_name: + continue + + if False and ".downsamplers." in m_name: + continue + + if True and ".resnets." in m_name: + if True and ".conv1" in m_name: + if False and "down_blocks" in m_name: + continue + if False and "mid_block" in m_name: + continue + if False and "up_blocks" in m_name: + continue + + if True and ".conv2" in m_name: + continue + + if True and ".conv_shortcut" in m_name: + continue + + if True and ".attentions." 
in m_name: + continue + + if False and m_name in ["conv_in", "conv_out"]: + continue + """ + + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + m.asymmetric_padding_mode = {} + m.asymmetric_padding = {} + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" + m.asymmetric_padding["x"] = ( + m._reversed_padding_repeated_twice[0], + m._reversed_padding_repeated_twice[1], + 0, + 0, + ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" + m.asymmetric_padding["y"] = ( + 0, + 0, + m._reversed_padding_repeated_twice[2], + m._reversed_padding_repeated_twice[3], + ) + + to_restore.append((m, m._conv_forward)) + m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d) + + yield + + finally: + for module, orig_conv_forward in to_restore: + module._conv_forward = orig_conv_forward + if hasattr(module, "asymmetric_padding_mode"): + del module.asymmetric_padding_mode + if hasattr(module, "asymmetric_padding"): + del module.asymmetric_padding diff --git a/invokeai/backend/util/silence_warnings.py b/invokeai/backend/util/silence_warnings.py new file mode 100644 index 0000000000..068b605da9 --- /dev/null +++ b/invokeai/backend/util/silence_warnings.py @@ -0,0 +1,28 @@ +"""Context class to silence transformers and diffusers warnings.""" +import warnings +from typing import Any + +from diffusers import logging as diffusers_logging +from transformers import logging as transformers_logging + + +class SilenceWarnings(object): + """Use in context to temporarily turn off warnings from transformers & diffusers modules. + + with SilenceWarnings(): + # do something + """ + + def __init__(self) -> None: + self.transformers_verbosity = transformers_logging.get_verbosity() + self.diffusers_verbosity = diffusers_logging.get_verbosity() + + def __enter__(self) -> None: + transformers_logging.set_verbosity_error() + diffusers_logging.set_verbosity_error() + warnings.simplefilter("ignore") + + def __exit__(self, *args: Any) -> None: + transformers_logging.set_verbosity(self.transformers_verbosity) + diffusers_logging.set_verbosity(self.diffusers_verbosity) + warnings.simplefilter("default") diff --git a/invokeai/frontend/install/model_install2.py b/invokeai/frontend/install/model_install2.py index 51a633a565..22b132370e 100644 --- a/invokeai/frontend/install/model_install2.py +++ b/invokeai/frontend/install/model_install2.py @@ -23,7 +23,7 @@ import torch from npyscreen import widget from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.model_install import ModelInstallService +from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.backend.install.install_helper import InstallHelper, InstallSelections, UnifiedModelInfo from invokeai.backend.model_manager import ModelType from invokeai.backend.util import choose_precision, choose_torch_device @@ -499,7 +499,7 @@ class AddModelApplication(npyscreen.NPSAppManaged): # type: ignore ) -def list_models(installer: ModelInstallService, model_type: ModelType): +def list_models(installer: ModelInstallServiceBase, model_type: ModelType): """Print out all models of type model_type.""" models = installer.record_store.search_by_attr(model_type=model_type) print(f"Installed models of type `{model_type}`:") @@ -527,7 +527,9 @@ def select_and_download_models(opt: Namespace) -> None: install_helper.add_or_delete(selections) elif opt.default_only: - selections = InstallSelections(install_models=[install_helper.default_model()]) + default_model = 
install_helper.default_model() + assert default_model is not None + selections = InstallSelections(install_models=[default_model]) install_helper.add_or_delete(selections) elif opt.yes_to_all: diff --git a/tests/backend/model_manager_2/util/test_hf_model_select.py b/tests/backend/model_manager_2/util/test_hf_model_select.py index f14d9a6823..5bef9cb2e1 100644 --- a/tests/backend/model_manager_2/util/test_hf_model_select.py +++ b/tests/backend/model_manager_2/util/test_hf_model_select.py @@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]: "text_encoder/model.onnx", "text_encoder_2/config.json", "text_encoder_2/model.onnx", + "text_encoder_2/model.onnx_data", "tokenizer/merges.txt", "tokenizer/special_tokens_map.json", "tokenizer/tokenizer_config.json", @@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]: "tokenizer_2/vocab.json", "unet/config.json", "unet/model.onnx", + "unet/model.onnx_data", "vae_decoder/config.json", "vae_decoder/model.onnx", "vae_encoder/config.json", diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index aacae06a8b..be823e2be9 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest -from invokeai.backend import BaseModelType +from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant from invokeai.backend.model_manager.probe import VaeFolderProbe @@ -21,10 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == "default" + assert repo_variant == ModelRepoVariant.DEFAULT def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == "fp16" + assert repo_variant == ModelRepoVariant.FP16
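
Reviewer note: below is a minimal sketch of how an invocation is expected to call the key-based loading API introduced by this patch. The `context.services.model_records` accessor and the model key/name shown are illustrative assumptions, not values taken from this change; only `load_model()`, `load_model_by_attr()`, and the load events come from the patch itself.

# Sketch only: the service accessor and the identifiers below are assumed for illustration.
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType


def load_example_models(context: InvocationContext) -> None:
    record_store = context.services.model_records  # assumed accessor for the record service

    # Preferred path after this change: the invocation carries only the model key.
    loaded_by_key = record_store.load_model(
        "0123456789abcdef",  # hypothetical model key stored in the invocation's field
        submodel=SubModelType.UNet,
        context=context,  # optional; enables model_load_started/completed events
    )

    # Transitional path that mirrors the old get_model() signature.
    loaded_by_attr = record_store.load_model_by_attr(
        model_name="stable-diffusion-v1-5",  # hypothetical model name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        submodel=SubModelType.UNet,
        context=context,
    )

    # Both calls return a LoadedModel. The other new backend helpers are plain
    # context managers, for example:
    #   with SilenceWarnings(): ...                 # mute transformers/diffusers warnings
    #   with set_seamless(unet, ["x", "y"]): ...    # circular padding on both axes

Passing `context` is optional; the SQL implementation only honors queue cancellation and emits the model-load events when a context is supplied.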