diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 3850fb6cc3..5159d5b89c 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
@@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_manager.load.load_model_by_key(
-                    **self.clip.text_encoder.model_dump(),
-                    context=context,
-                ).model
+                loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
                 assert isinstance(loaded_model, TextualInversionModelRaw)
                 ti_list.append((name, loaded_model))
             except UnknownModelException:
@@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 lora_model = lora_info.model
                 assert isinstance(lora_model, LoRAModelRaw)
                 yield (lora_model, lora.weight)
@@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.services.model_manager.load.load_model_by_attr(
-                    model_name=name,
-                    base_model=text_encoder_info.config.base,
-                    model_type=ModelType.TextualInversion,
-                    context=context,
+                ti_model = context.models.load_by_attrs(
+                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
                 ).model
                 assert isinstance(ti_model, TextualInversionModelRaw)
                 ti_list.append((name, ti_model))
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index 01124f62f3..15e254010b 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -14,8 +14,7 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType
-from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
 
 # LS: Consider moving these two classes into model.py
@@ -90,10 +89,10 @@ class IPAdapterInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
-        ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
+        ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
         image_encoder_model_id = ip_adapter_info.image_encoder_model_id
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
-        image_encoder_models = context.services.model_manager.store.search_by_attr(
+        image_encoder_models = context.models.search_by_attrs(
             model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
         )
         assert len(image_encoder_models) == 1
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 05293fdfee..5dd0eb074d 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -141,7 +141,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.services.images.get_pil_image(self.image.image_name)
+            image = context.images.get_pil(self.image.image_name)
             image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image_tensor.dim() == 3:
                 image_tensor = image_tensor.unsqueeze(0)
@@ -153,10 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         )
 
         if image_tensor is not None:
-            vae_info = context.services.model_manager.load.load_model_by_key(
-                **self.vae.vae.model_dump(),
-                context=context,
-            )
+            vae_info = context.models.load(**self.vae.vae.model_dump())
 
             img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -182,10 +179,7 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
-        **scheduler_info.model_dump(),
-        context=context,
-    )
+    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
 
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
@@ -399,12 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
            # and if weight is None, populate with default 1.0?
            controlnet_data = []
            for control_info in control_list:
-                control_model = exit_stack.enter_context(
-                    context.services.model_manager.load.load_model_by_key(
-                        key=control_info.control_model.key,
-                        context=context,
-                    )
-                )
+                control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
 
                # control_models.append(control_model)
                control_image_field = control_info.image
@@ -466,25 +455,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
         conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_manager.load.load_model_by_key(
-                    key=single_ip_adapter.ip_adapter_model.key,
-                    context=context,
-                )
+                context.models.load(key=single_ip_adapter.ip_adapter_model.key)
             )
 
-            image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
-                key=single_ip_adapter.image_encoder_model.key,
-                context=context,
-            )
+            image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
 
             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
             single_ipa_image_fields = single_ip_adapter.image
             if not isinstance(single_ipa_image_fields, list):
                 single_ipa_image_fields = [single_ipa_image_fields]
 
-            single_ipa_images = [
-                context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
-            ]
+            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -528,10 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
-                key=t2i_adapter_field.t2i_adapter_model.key,
-                context=context,
-            )
+            t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
             image = context.images.get_pil(t2i_adapter_field.image.image_name)
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -676,30 +654,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
             do_classifier_free_guidance=True,
         )
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
         # get the unet's config so that we can pass the base to dispatch_progress()
-        unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
+        unet_config = context.models.get_config(self.unet.unet.key)
 
         def step_callback(state: PipelineIntermediateState) -> None:
-            self.dispatch_progress(context, source_node_id, state, unet_config.base)
+            context.util.sd_step_callback(state, unet_config.base)
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}),
-                    context=context,
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
-        unet_info = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_info = context.models.load(**self.unet.unet.model_dump())
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
@@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, torch.nn.Module)
@@ -1032,10 +997,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -1239,10 +1201,7 @@ class IdealSizeInvocation(BaseInvocation):
         return tuple((x - x % multiple_of) for x in args)
 
     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
-        unet_config = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_config = context.models.get_config(**self.unet.unet.model_dump())
         aspect = self.width / self.height
         dimension: float = 512
         if unet_config.base == BaseModelType.StableDiffusion2:
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 71a71a63c8..6087bc82db 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
         key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(key):
+        if not context.models.exists(key):
             raise Exception(f"Unknown model {key}")
 
         return ModelLoaderOutput(
@@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
 
         lora_key = self.lora.key
 
-        if not context.services.model_manager.store.exists(lora_key):
+        if not context.models.exists(lora_key):
             raise Exception(f"Unkown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
 
         lora_key = self.lora.key
 
-        if not context.services.model_manager.store.exists(lora_key):
+        if not context.models.exists(lora_key):
             raise Exception(f"Unknown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> VAEOutput:
         key = self.vae_model.key
 
-        if not context.services.model_manager.store.exists(key):
+        if not context.models.exists(key):
             raise Exception(f"Unkown vae: {key}!")
 
         return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 85e6fb787f..0df27c0011 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -43,7 +43,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(model_key):
+        if not context.models.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")
 
         return SDXLModelLoaderOutput(
@@ -112,7 +112,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(model_key):
+        if not context.models.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")
 
         return SDXLRefinerModelLoaderOutput(
diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index 45eaf4652f..f4dd905135 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -19,14 +19,14 @@ class ModelLoadServiceBase(ABC):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's key, load it and return the LoadedModel object.
 
         :param key: Key of model config to be fetched.
         :param submodel: For main (pipeline models), the submodel to fetch.
-        :param context: Invocation context used for event reporting
+        :param context_data: Invocation context data used for event reporting
         """
         pass
 
@@ -35,14 +35,14 @@ class ModelLoadServiceBase(ABC):
         self,
         model_config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
 
         :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
         :param submodel: For main (pipeline models), the submodel to fetch.
-        :param context: Invocation context used for event reporting
+        :param context_data: Invocation context data used for event reporting
         """
         pass
 
@@ -53,7 +53,7 @@ class ModelLoadServiceBase(ABC):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@@ -66,7 +66,7 @@ class ModelLoadServiceBase(ABC):
         :param base_model: Base model
         :param model_type: Type of the model
         :param submodel: For main (pipeline models), the submodel to fetch
-        :param context: The invocation context.
+        :param context_data: The invocation context data.
 
         Exceptions: UnknownModelException -- model with these attributes not known
         NotImplementedException -- a model loader was not provided at initialization time
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index a6ccd5afbc..29b297c814 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -3,10 +3,11 @@
 
 from typing import Optional
 
-from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -46,6 +47,9 @@ class ModelLoadService(ModelLoadServiceBase):
             ),
         )
 
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+
     @property
     def ram_cache(self) -> ModelCacheBase[AnyModel]:
         """Return the RAM cache used by this loader."""
@@ -60,7 +64,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's key, load it and return the LoadedModel object.
@@ -70,7 +74,7 @@ class ModelLoadService(ModelLoadServiceBase):
         :param context: Invocation context used for event reporting
         """
         config = self._store.get_model(key)
-        return self.load_model_by_config(config, submodel_type, context)
+        return self.load_model_by_config(config, submodel_type, context_data)
 
     def load_model_by_attr(
         self,
@@ -78,7 +82,7 @@ class ModelLoadService(ModelLoadServiceBase):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@@ -109,7 +113,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self,
         model_config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
@@ -118,15 +122,15 @@ class ModelLoadService(ModelLoadServiceBase):
         :param submodel: For main (pipeline models), the submodel to fetch.
         :param context: Invocation context used for event reporting
         """
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
             )
         loaded_model = self._any_loader.load_model(model_config, submodel_type)
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
                 loaded=True,
             )
@@ -134,26 +138,28 @@ class ModelLoadService(ModelLoadServiceBase):
 
     def _emit_load_event(
         self,
-        context: InvocationContext,
+        context_data: InvocationContextData,
         model_config: AnyModelConfig,
         loaded: Optional[bool] = False,
     ) -> None:
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
+        if not self._invoker:
+            return
+        if self._invoker.services.queue.is_canceled(context_data.session_id):
             raise CanceledException()
 
         if not loaded:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_started(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )
         else:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_completed(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index c68dc1140b..089d09f825 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
 from PIL.Image import Image
@@ -12,8 +13,9 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.model_manager import LoadedModelInfo
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 
@@ -259,45 +261,95 @@ class ConditioningInterface(InvocationContextInterface):
 
 
 class ModelsInterface(InvocationContextInterface):
-    def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
+    def exists(self, key: str) -> bool:
         """
         Checks if a model exists.
 
-        :param model_name: The name of the model to check.
-        :param base_model: The base model of the model to check.
-        :param model_type: The type of the model to check.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_exists(model_name, base_model, model_type)
+        return self._services.model_manager.store.exists(key)
 
-    def load(
-        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
-    ) -> LoadedModelInfo:
+    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Loads a model.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
-        :param submodel: The submodel of the model to get.
+        :param key: The key of the model.
+        :param submodel_type: The submodel of the model to get.
 
         :returns: An object representing the loaded model.
         """
 
         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.
-        return self._services.model_manager.get_model(
-            model_name, base_model, model_type, submodel, context_data=self._context_data
+        return self._services.model_manager.load.load_model_by_key(
+            key=key, submodel_type=submodel_type, context_data=self._context_data
         )
 
-    def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def load_by_attrs(
+        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
+    ) -> LoadedModel:
+        """
+        Loads a model by its attributes.
+
+        :param model_name: Name of to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+        return self._services.model_manager.load.load_model_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+            context_data=self._context_data,
+        )
+
+    def get_config(self, key: str) -> AnyModelConfig:
         """
         Gets a model's info, an dict-like object.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_info(model_name, base_model, model_type)
+        return self._services.model_manager.store.get_model(key=key)
+
+    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+        """
+        Gets a model's metadata, if it has any.
+
+        :param key: The key of the model.
+        """
+        return self._services.model_manager.store.get_metadata(key=key)
+
+    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
+        """
+        Searches for models by path.
+
+        :param path: The path to search for.
+        """
+        return self._services.model_manager.store.search_by_path(path)
+
+    def search_by_attrs(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
+        model_format: Optional[ModelFormat] = None,
+    ) -> list[AnyModelConfig]:
+        """
+        Searches for models by attributes.
+
+        :param model_name: Name of to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+
+        return self._services.model_manager.store.search_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_format=model_format,
+        )
 
 
 class ConfigInterface(InvocationContextInterface):
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index d83b380d95..33d00ca366 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -4,8 +4,8 @@
 import torch
 from PIL import Image
 
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
+from invokeai.backend.model_manager.config import BaseModelType
 
-from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL