From 528ac5dd2592ceaf20eb216ae32c7eae85d09a81 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 6 Mar 2024 19:37:15 +1100 Subject: [PATCH] refactor(nodes): model identifiers - All models are identified by a key and optionally a submodel type via new model `ModelField`. Previously, a few model types had their own class, but not all of them. This inconsistency just added complexity without any benefit. - Update all invocation to use the new format. - In the node API, models are loaded by key or an instance of `ModelField` as a convenience. - Add an enriched model schema for metadata. It includes key, hash, name, base and type. --- invokeai/app/invocations/compel.py | 12 +- .../controlnet_image_processors.py | 11 +- invokeai/app/invocations/fields.py | 2 +- invokeai/app/invocations/ip_adapter.py | 22 +-- invokeai/app/invocations/latent.py | 32 ++-- invokeai/app/invocations/metadata.py | 69 ++++++-- invokeai/app/invocations/model.py | 150 ++++++------------ invokeai/app/invocations/sdxl.py | 101 +++--------- invokeai/app/invocations/t2i_adapter.py | 9 +- .../model_manager/model_manager_default.py | 1 - .../model_records/model_records_base.py | 11 ++ .../model_records/model_records_sql.py | 15 ++ .../app/services/shared/invocation_context.py | 62 ++++---- invokeai/app/util/ti_utils.py | 2 +- invokeai/invocation_api/__init__.py | 18 +-- 15 files changed, 229 insertions(+), 288 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index ff13658052..a2e0bd06c4 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + tokenizer_info = context.models.load(self.clip.tokenizer) tokenizer_model = tokenizer_info.model assert isinstance(tokenizer_model, CLIPTokenizer) - text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) + text_encoder_info = context.models.load(self.clip.text_encoder) text_encoder_model = text_encoder_info.model assert isinstance(text_encoder_model, CLIPTextModel) def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.clip.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + lora_info = context.models.load(lora.lora) assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) del lora_info @@ -133,10 +133,10 @@ class SDXLPromptInvocationBase: lora_prefix: str, zero_on_empty: bool, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: - tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_model = tokenizer_info.model assert isinstance(tokenizer_model, CLIPTokenizer) - text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) + text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_model = text_encoder_info.model assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection)) @@ -163,7 +163,7 @@ class SDXLPromptInvocationBase: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + lora_info = context.models.load(lora.lora) lora_model = 
lora_info.model assert isinstance(lora_model, LoRAModelRaw) yield (lora_model, lora.weight) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 9eba3acdca..fb070df69c 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import ( WithBoard, WithMetadata, ) +from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[ ] -class ControlNetModelField(BaseModel): - """ControlNet model field""" - - key: str = Field(description="Model config record key for the ControlNet model") - - class ControlField(BaseModel): image: ImageField = Field(description="The control image") - control_model: ControlNetModelField = Field(description="The ControlNet model to use") + control_model: ModelField = Field(description="The ControlNet model to use") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" @@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" image: ImageField = InputField(description="The control image") - control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct) + control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct) control_weight: Union[float, List[float]] = InputField( default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 712ab415b0..f19c518d0f 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -228,7 +228,7 @@ class ConditioningField(BaseModel): # endregion -class MetadataField(RootModel): +class MetadataField(RootModel[dict[str, Any]]): """ Pydantic model for metadata with custom root of type dict[str, Any]. Metadata is stored without a strict schema. 
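For orientation before the remaining file diffs: below is a minimal, self-contained sketch of the field shapes this patch introduces and the calling convention they enable. `ModelField` and `LoRAField` mirror the definitions added to `invokeai/app/invocations/model.py` later in this patch; the `SubModelType` enum here is a simplified stand-in, and the `context.models.load(...)` lines are shown as comments because they need a live `InvocationContext`.

```python
# Sketch only: simplified stand-ins for the shapes introduced by this patch.
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field


class SubModelType(str, Enum):
    """Stand-in for invokeai.backend.model_manager.config.SubModelType."""

    UNet = "unet"
    Scheduler = "scheduler"
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    VAE = "vae"


class ModelField(BaseModel):
    key: str = Field(description="Key of the model")
    submodel_type: Optional[SubModelType] = Field(default=None, description="Submodel type")


class LoRAField(BaseModel):
    lora: ModelField = Field(description="Info to load lora model")
    weight: float = Field(description="Weight to apply to lora model")


# Node code previously splatted per-type model classes into keyword arguments:
#   tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
# With this patch it passes the field (or a bare key) directly:
#   tokenizer_info = context.models.load(self.clip.tokenizer)
#   lora_info = context.models.load(lora.lora)  # lora is a LoRAField

# Loader invocations derive submodel references with model_copy instead of
# constructing a separate ModelInfo per submodel:
main_model = ModelField(key="some-model-key")  # hypothetical key
unet = main_model.model_copy(update={"submodel_type": SubModelType.UNet})
scheduler = main_model.model_copy(update={"submodel_type": SubModelType.Scheduler})
print(unet.submodel_type, scheduler.submodel_type)
```

The `model_copy(update={"submodel_type": ...})` pattern shown at the end is the one `MainModelLoaderInvocation` and the SDXL loaders switch to further down in this patch.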
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index bebdc29b86..58ed5166e6 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.model_manager.config import BaseModelType, ModelType - - -# LS: Consider moving these two classes into model.py -class IPAdapterModelField(BaseModel): - key: str = Field(description="Key to the IP-Adapter model") - - -class CLIPVisionModelField(BaseModel): - key: str = Field(description="Key to the CLIP Vision image encoder model") +from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType class IPAdapterField(BaseModel): image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).") - ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.") - image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.") + ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.") + image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.") weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)" @@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation): # Inputs image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") - ip_adapter_model: IPAdapterModelField = InputField( + ip_adapter_model: ModelField = InputField( description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1 ) @@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) + assert isinstance(ip_adapter_info, IPAdapterConfig) image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_models = context.models.search_by_attrs( name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 - image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) return IPAdapterOutput( ip_adapter=IPAdapterField( image=self.image, ip_adapter_model=self.ip_adapter_model, - image_encoder_model=image_encoder_model, + image_encoder_model=ModelField(key=image_encoder_models[0].key), weight=self.weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 5931b1b8f7..c3704b2ed8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler from PIL import Image, ImageFilter from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize +from transformers import CLIPVisionModelWithProjection from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.fields import ( @@ -75,7 +76,7 @@ from .baseinvocation import ( invocation_output, ) from .controlnet_image_processors import ControlField -from .model import ModelInfo, UNetField, VaeField +from .model import ModelField, UNetField, VaeField if choose_torch_device() == torch.device("mps"): from torch import mps @@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation): ) if image_tensor is not None: - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.models.load(self.vae.vae) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) @@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation): def get_scheduler( context: InvocationContext, - scheduler_info: ModelInfo, + scheduler_info: ModelField, scheduler_name: str, seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) + orig_scheduler_info = context.models.load(scheduler_info) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # and if weight is None, populate with default 1.0? 
controlnet_data = [] for control_info in control_list: - control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key)) + control_model = exit_stack.enter_context(context.models.load(control_info.control_model)) # control_models.append(control_model) control_image_field = control_info.image @@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.models.load(key=single_ip_adapter.ip_adapter_model.key) + context.models.load(single_ip_adapter.ip_adapter_model) ) - image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key) - + image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. single_ipa_image_fields = single_ip_adapter.image if not isinstance(single_ipa_image_fields, list): @@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. with image_encoder_model_info as image_encoder_model: + assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) # Get image embeddings from CLIP and ImageProjModel. image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds( single_ipa_images, image_encoder_model @@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation): t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key) - t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key) + t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key) + t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model) image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. 
@@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation): def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: - lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) + lora_info = context.models.load(lora.lora) + assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) del lora_info return - unet_info = context.models.load(**self.unet.unet.model_dump()) + unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) with ( ExitStack() as exit_stack, @@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.tensors.load(self.latents.latents_name) - vae_info = context.models.load(**self.vae.vae.model_dump()) - + vae_info = context.models.load(self.vae.vae) + assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL)) with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: assert isinstance(vae, torch.nn.Module) latents = latents.to(vae.device) @@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: image = context.images.get_pil(self.image.image_name) - vae_info = context.models.load(**self.vae.vae.model_dump()) + vae_info = context.models.load(self.vae.vae) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index bec1b0d9d5..7161caf260 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.controlnet_image_processors import ControlField +from invokeai.app.invocations.controlnet_image_processors import ( + CONTROLNET_MODE_VALUES, + CONTROLNET_RESIZE_VALUES, +) from invokeai.app.invocations.fields import ( FieldDescriptions, ImageField, @@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import ( OutputField, UIType, ) -from invokeai.app.invocations.ip_adapter import IPAdapterModelField -from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.config import BaseModelType, ModelType from ...version import __version__ @@ -30,10 +31,20 @@ class MetadataItemField(BaseModel): value: Any = Field(description=FieldDescriptions.metadata_item_value) +class ModelMetadataField(BaseModel): + """Model Metadata Field""" + + key: str + hash: str + name: str + base: BaseModelType + type: ModelType + + class LoRAMetadataField(BaseModel): """LoRA Metadata Field""" - model: LoRAModelField = Field(description=FieldDescriptions.lora_model) + model: ModelMetadataField = Field(description=FieldDescriptions.lora_model) weight: float = Field(description=FieldDescriptions.lora_weight) @@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel): """IP Adapter Field, minus the CLIP Vision Encoder model""" image: ImageField = Field(description="The IP-Adapter image prompt.") - ip_adapter_model: IPAdapterModelField = Field( + ip_adapter_model: ModelMetadataField = Field( description="The IP-Adapter model.", ) weight: Union[float, list[float]] = Field( @@ 
-51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel): end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") +class T2IAdapterMetadataField(BaseModel): + image: ImageField = Field(description="The T2I-Adapter image prompt.") + t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.") + weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") + begin_step_percent: float = Field( + default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" + ) + end_step_percent: float = Field( + default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)" + ) + resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + + +class ControlNetMetadataField(BaseModel): + image: ImageField = Field(description="The control image") + control_model: ModelMetadataField = Field(description="The ControlNet model to use") + control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet") + begin_step_percent: float = Field( + default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" + ) + end_step_percent: float = Field( + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") + resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") + + @invocation_output("metadata_item_output") class MetadataItemOutput(BaseInvocationOutput): """Metadata Item Output""" @@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The number of skipped CLIP layers", ) - model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference") - controlnets: Optional[list[ControlField]] = InputField( + model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference") + controlnets: Optional[list[ControlNetMetadataField]] = InputField( default=None, description="The ControlNets used for inference" ) ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField( default=None, description="The IP Adapters used for inference" ) - t2iAdapters: Optional[list[T2IAdapterField]] = InputField( + t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField( default=None, description="The IP Adapters used for inference" ) loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference") @@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation): default=None, description="The name of the initial image", ) - vae: Optional[VAEModelField] = InputField( + vae: Optional[ModelMetadataField] = InputField( default=None, description="The VAE used for decoding, if the main model's default was not used", ) @@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation): ) # SDXL Refiner - refiner_model: Optional[MainModelField] = InputField( + refiner_model: Optional[ModelMetadataField] = InputField( default=None, description="The SDXL Refiner model used", ) @@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" - return MetadataOutput( - 
metadata=MetadataField.model_validate( - self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"}) - ) - ) + as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"}) + as_dict["app_version"] = __version__ + + return MetadataOutput(metadata=MetadataField.model_validate(as_dict)) model_config = ConfigDict(extra="allow") diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index cb69558be5..648f34e749 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, Field from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.shared.models import FreeUConfig +from invokeai.backend.model_manager.config import SubModelType -from ...backend.model_manager import SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -16,33 +16,34 @@ from .baseinvocation import ( ) -class ModelInfo(BaseModel): - key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()") - submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel") +class ModelField(BaseModel): + key: str = Field(description="Key of the model") + submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None) -class LoraInfo(ModelInfo): - weight: float = Field(description="Lora's weight which to use when apply to model") +class LoRAField(BaseModel): + lora: ModelField = Field(description="Info to load lora model") + weight: float = Field(description="Weight to apply to lora model") class UNetField(BaseModel): - unet: ModelInfo = Field(description="Info to load unet submodel") - scheduler: ModelInfo = Field(description="Info to load scheduler submodel") - loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + unet: ModelField = Field(description="Info to load unet submodel") + scheduler: ModelField = Field(description="Info to load scheduler submodel") + loras: List[LoRAField] = Field(description="Loras to apply on model loading") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration") class ClipField(BaseModel): - tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") - text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") + tokenizer: ModelField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelField = Field(description="Info to load text_encoder submodel") skipped_layers: int = Field(description="Number of skipped layers in text_encoder") - loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + loras: List[LoRAField] = Field(description="Loras to apply on model loading") class VaeField(BaseModel): # TODO: better naming? 
- vae: ModelInfo = Field(description="Info to load vae submodel") + vae: ModelField = Field(description="Info to load vae submodel") seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless') @@ -74,18 +75,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): pass -class MainModelField(BaseModel): - """Main model field""" - - key: str = Field(description="Model key") - - -class LoRAModelField(BaseModel): - """LoRA model field""" - - key: str = Field(description="LoRA model key") - - @invocation( "main_model_loader", title="Main Model", @@ -96,46 +85,24 @@ class LoRAModelField(BaseModel): class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" - model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) + model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - key = self.model.key - # TODO: not found exceptions - if not context.models.exists(key): - raise Exception(f"Unknown model {key}") + if not context.models.exists(self.model.key): + raise Exception(f"Unknown model {self.model.key}") + + unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) return ModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - key=key, - submodel_type=SubModelType.UNet, - ), - scheduler=ModelInfo( - key=key, - submodel_type=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - key=key, - submodel_type=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - key=key, - submodel_type=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae=VaeField( - vae=ModelInfo( - key=key, - submodel_type=SubModelType.VAE, - ), - ), + unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), + clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), + vae=VaeField(vae=vae), ) @@ -151,7 +118,7 @@ class LoraLoaderOutput(BaseInvocationOutput): class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" - lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( default=None, @@ -167,38 +134,33 @@ class LoraLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> LoraLoaderOutput: - if self.lora is None: - raise Exception("No LoRA provided") - lora_key = self.lora.key if not context.models.exists(lora_key): raise Exception(f"Unkown lora: {lora_key}!") - if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras): raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + if self.clip 
is not None and any(lora.lora.key == lora_key for lora in self.clip.loras): raise Exception(f'Lora "{lora_key}" already applied to clip') output = LoraLoaderOutput() if self.unet is not None: - output.unet = copy.deepcopy(self.unet) + output.unet = self.unet.model_copy(deep=True) output.unet.loras.append( - LoraInfo( - key=lora_key, - submodel_type=None, + LoRAField( + lora=self.lora, weight=self.weight, ) ) if self.clip is not None: - output.clip = copy.deepcopy(self.clip) + output.clip = self.clip.model_copy(deep=True) output.clip.loras.append( - LoraInfo( - key=lora_key, - submodel_type=None, + LoRAField( + lora=self.lora, weight=self.weight, ) ) @@ -225,7 +187,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" - lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( default=None, @@ -247,51 +209,45 @@ class SDXLLoraLoaderInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: - if self.lora is None: - raise Exception("No LoRA provided") - lora_key = self.lora.key if not context.models.exists(lora_key): raise Exception(f"Unknown lora: {lora_key}!") - if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): + if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras): raise Exception(f'Lora "{lora_key}" already applied to unet') - if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras): + if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras): raise Exception(f'Lora "{lora_key}" already applied to clip') - if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras): + if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras): raise Exception(f'Lora "{lora_key}" already applied to clip2') output = SDXLLoraLoaderOutput() if self.unet is not None: - output.unet = copy.deepcopy(self.unet) + output.unet = self.unet.model_copy(deep=True) output.unet.loras.append( - LoraInfo( - key=lora_key, - submodel_type=None, + LoRAField( + lora=self.lora, weight=self.weight, ) ) if self.clip is not None: - output.clip = copy.deepcopy(self.clip) + output.clip = self.clip.model_copy(deep=True) output.clip.loras.append( - LoraInfo( - key=lora_key, - submodel_type=None, + LoRAField( + lora=self.lora, weight=self.weight, ) ) if self.clip2 is not None: - output.clip2 = copy.deepcopy(self.clip2) + output.clip2 = self.clip2.model_copy(deep=True) output.clip2.loras.append( - LoraInfo( - key=lora_key, - submodel_type=None, + LoRAField( + lora=self.lora, weight=self.weight, ) ) @@ -299,17 +255,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation): return output -class VAEModelField(BaseModel): - """Vae model field""" - - key: str = Field(description="Model's key") - - @invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" - vae_model: VAEModelField = InputField( + vae_model: ModelField = InputField( description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", @@ -321,7 
+271,7 @@ class VaeLoaderInvocation(BaseInvocation): if not context.models.exists(key): raise Exception(f"Unkown vae: {key}!") - return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) + return VAEOutput(vae=VaeField(vae=self.vae_model)) @invocation_output("seamless_output") diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 4e783defec..77c825a3eb 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -8,7 +8,7 @@ from .baseinvocation import ( invocation, invocation_output, ) -from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField +from .model import ClipField, ModelField, UNetField, VaeField @invocation_output("sdxl_model_loader_output") @@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" - model: MainModelField = InputField( + model: ModelField = InputField( description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel ) # TODO: precision? @@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation): if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") + unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) + text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) + tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) + text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + return SDXLModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - key=model_key, - submodel_type=SubModelType.UNet, - ), - scheduler=ModelInfo( - key=model_key, - submodel_type=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - key=model_key, - submodel_type=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - key=model_key, - submodel_type=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - clip2=ClipField( - tokenizer=ModelInfo( - key=model_key, - submodel_type=SubModelType.Tokenizer2, - ), - text_encoder=ModelInfo( - key=model_key, - submodel_type=SubModelType.TextEncoder2, - ), - loras=[], - skipped_layers=0, - ), - vae=VaeField( - vae=ModelInfo( - key=model_key, - submodel_type=SubModelType.VAE, - ), - ), + unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), + clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0), + clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + vae=VaeField(vae=vae), ) @@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" - model: MainModelField = InputField( - description=FieldDescriptions.sdxl_refiner_model, - input=Input.Direct, - ui_type=UIType.SDXLRefinerModel, + model: ModelField = InputField( + description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel ) # TODO: precision? 
@@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): if not context.models.exists(model_key): raise Exception(f"Unknown model: {model_key}") + unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet}) + scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler}) + tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) + text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) + vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE}) + return SDXLRefinerModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - key=model_key, - submodel_type=SubModelType.UNet, - ), - scheduler=ModelInfo( - key=model_key, - submodel_type=SubModelType.Scheduler, - ), - loras=[], - ), - clip2=ClipField( - tokenizer=ModelInfo( - key=model_key, - submodel_type=SubModelType.Tokenizer2, - ), - text_encoder=ModelInfo( - key=model_key, - submodel_type=SubModelType.TextEncoder2, - ), - loras=[], - skipped_layers=0, - ), - vae=VaeField( - vae=ModelInfo( - key=model_key, - submodel_type=SubModelType.VAE, - ), - ), + unet=UNetField(unet=unet, scheduler=scheduler, loras=[]), + clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0), + vae=VaeField(vae=vae), ) diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 0f1e251bb3..4b1c5e36b4 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.model import ModelField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext -class T2IAdapterModelField(BaseModel): - key: str = Field(description="Model record key for the T2I-Adapter model") - - class T2IAdapterField(BaseModel): image: ImageField = Field(description="The T2I-Adapter image prompt.") - t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.") + t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.") weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)" @@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation): # Inputs image: ImageField = InputField(description="The IP-Adapter image prompt.") - t2i_adapter_model: T2IAdapterModelField = InputField( + t2i_adapter_model: ModelField = InputField( description="The T2I-Adapter model.", title="T2I-Adapter Model", input=Input.Direct, diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 83632d0c0f..bf5bf44ca3 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,7 +1,6 @@ # Copyright (c) 2023 Lincoln D. 
Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" - import torch from typing_extensions import Self diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 60f4e36b4d..0b56b7f6c0 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -130,6 +130,17 @@ class ModelRecordServiceBase(ABC): """ pass + @abstractmethod + def get_model_by_hash(self, hash: str) -> AnyModelConfig: + """ + Retrieve the configuration for the indicated model. + + :param hash: Hash of model config to be fetched. + + Exceptions: UnknownModelException + """ + pass + @abstractmethod def list_models( self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 35c182fb9d..51f704a5e9 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) return model + def get_model_by_hash(self, hash: str) -> AnyModelConfig: + with self._db.lock: + self._cursor.execute( + """--sql + SELECT config, strftime('%s',updated_at) FROM models + WHERE hash=?; + """, + (hash,), + ) + rows = self._cursor.fetchone() + if not rows: + raise UnknownModelException("model not found") + model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1]) + return model + def exists(self, key: str) -> bool: """ Return True if a model with the indicated key exists in the databse. 
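A short usage sketch for the `get_model_by_hash` record-store method added above. The helper function is hypothetical (not part of this patch); it only demonstrates the lookup-by-hash call and the `UnknownModelException` raised when no record matches.

```python
# Hypothetical helper demonstrating the new lookup-by-hash API; not part of this patch.
from typing import Optional

from invokeai.app.services.model_records.model_records_base import (
    ModelRecordServiceBase,
    UnknownModelException,
)
from invokeai.backend.model_manager.config import AnyModelConfig


def find_config_by_hash(store: ModelRecordServiceBase, hash: str) -> Optional[AnyModelConfig]:
    """Return the config whose hash matches, or None if no such model is installed."""
    try:
        return store.get_model_by_hash(hash)
    except UnknownModelException:
        return None
```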
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 7d378e22e3..abf131a125 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,7 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from PIL.Image import Image from torch import Tensor @@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.load.load_base import LoadedModel -from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation + from invokeai.app.invocations.model import ModelField from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem """ @@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface): - def exists(self, key: str) -> bool: + def exists(self, identifier: Union[str, "ModelField"]) -> bool: """Checks if a model exists. Args: - key: The key of the model. + identifier: The key or ModelField representing the model. Returns: True if the model exists, False if not. """ - return self._services.model_manager.store.exists(key) + if isinstance(identifier, str): + return self._services.model_manager.store.exists(identifier) - def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + return self._services.model_manager.store.exists(identifier.key) + + def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel: """Loads a model. Args: - key: The key of the model. + identifier: The key or ModelField representing the model. submodel_type: The submodel of the model to get. Returns: @@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface): # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. 
- return self._services.model_manager.load_model_by_key( - key=key, submodel_type=submodel_type, context_data=self._data - ) + if isinstance(identifier, str): + model = self._services.model_manager.store.get_model(identifier) + return self._services.model_manager.load.load_model(model, submodel_type, self._data) + else: + _submodel_type = submodel_type or identifier.submodel_type + model = self._services.model_manager.store.get_model(identifier.key) + return self._services.model_manager.load.load_model(model, _submodel_type, self._data) def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None @@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface): Returns: An object representing the loaded model. """ - return self._services.model_manager.load_model_by_attr( - model_name=name, - base_model=base, - model_type=type, - submodel=submodel_type, - context_data=self._data, - ) - def get_config(self, key: str) -> AnyModelConfig: + configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type) + if len(configs) == 0: + raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}") + + if len(configs) > 1: + raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}") + + return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data) + + def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig: """Gets a model's config. Args: - key: The key of the model. + identifier: The key or ModelField representing the model. Returns: The model's config. """ - return self._services.model_manager.store.get_model(key=key) + if isinstance(identifier, str): + return self._services.model_manager.store.get_model(identifier) - def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]: - """Gets a model's metadata, if it has any. - - Args: - key: The key of the model. - - Returns: - The model's metadata, if it has any. - """ - return self._services.model_manager.store.get_metadata(key=key) + return self._services.model_manager.store.get_model(identifier.key) def search_by_path(self, path: Path) -> list[AnyModelConfig]: """Searches for models by path. 
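To summarize the reworked `ModelsInterface` above: `exists`, `load`, and `get_config` now accept either a bare key or a `ModelField`, and `load_by_attrs` resolves a unique model by name, base, and type. The sketch below is illustrative only; the helper and the literal names in it are hypothetical, and `context` is the `InvocationContext` an invocation receives.

```python
# Hypothetical helper showing the ways the reworked interface accepts model identifiers.
from typing import TYPE_CHECKING

from invokeai.backend.model_manager.config import BaseModelType, ModelType

if TYPE_CHECKING:
    from invokeai.app.invocations.model import ModelField
    from invokeai.app.services.shared.invocation_context import InvocationContext


def load_models_example(context: "InvocationContext", vae_field: "ModelField") -> None:
    # 1. By bare key, as ti_utils.py now does for textual-inversion triggers:
    context.models.load("some-model-key")  # hypothetical key

    # 2. By ModelField; its submodel_type is used unless one is passed explicitly:
    context.models.load(vae_field)
    context.models.get_config(vae_field)

    # 3. By attributes; raises UnknownModelException on no match, ValueError on several:
    context.models.load_by_attrs(
        name="ip_adapter_sd_image_encoder",  # hypothetical name
        base=BaseModelType.Any,
        type=ModelType.CLIPVision,
    )
```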
diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py index d204a40183..34669fe64e 100644 --- a/invokeai/app/util/ti_utils.py +++ b/invokeai/app/util/ti_utils.py @@ -22,7 +22,7 @@ def generate_ti_list( for trigger in extract_ti_triggers_from_prompt(prompt): name_or_key = trigger[1:-1] try: - loaded_model = context.models.load(key=name_or_key) + loaded_model = context.models.load(name_or_key) model = loaded_model.model assert isinstance(model, TextualInversionModelRaw) assert loaded_model.config.base == base diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index e110b5a2db..492b5c1f4c 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -35,17 +35,13 @@ from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOut from invokeai.app.invocations.model import ( ClipField, CLIPOutput, - LoraInfo, LoraLoaderOutput, - LoRAModelField, - MainModelField, - ModelInfo, + ModelField, ModelLoaderOutput, SDXLLoraLoaderOutput, UNetField, UNetOutput, VaeField, - VAEModelField, VAEOutput, ) from invokeai.app.invocations.primitives import ( @@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.util.misc import SEED_MAX, get_random_seed -from invokeai.backend.model_management.model_manager import LoadedModelInfo -from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, @@ -118,14 +114,10 @@ __all__ = [ "MetadataItemOutput", "MetadataOutput", # invokeai.app.invocations.model - "ModelInfo", - "LoraInfo", + "ModelField", "UNetField", "ClipField", "VaeField", - "MainModelField", - "LoRAModelField", - "VAEModelField", "UNetOutput", "VAEOutput", "CLIPOutput", @@ -166,7 +158,7 @@ __all__ = [ # invokeai.app.services.config.config_default "InvokeAIAppConfig", # invokeai.backend.model_management.model_manager - "LoadedModelInfo", + "LoadedModel", # invokeai.backend.model_management.models.base "BaseModelType", "ModelType",
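Finally, a hedged sketch of the enriched model schema for metadata described in the commit message (`ModelMetadataField`, defined in `invokeai/app/invocations/metadata.py` above). The values here are placeholders; in a real graph these fields would be filled from the installed model's record.

```python
# Placeholder values only; real graphs fill these from the model record.
from invokeai.app.invocations.metadata import ModelMetadataField
from invokeai.backend.model_manager.config import BaseModelType, ModelType

meta = ModelMetadataField(
    key="some-model-key",   # hypothetical key
    hash="b7e2c5...",       # hypothetical hash
    name="My SDXL Model",   # hypothetical name
    base=BaseModelType.StableDiffusionXL,
    type=ModelType.Main,
)
print(meta.model_dump_json(indent=2))
```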