feat(nodes): update invocation context for mm2, update nodes model usage

This commit is contained in:
psychedelicious 2024-02-15 20:43:41 +11:00 committed by Brandon Rising
parent 7a36cd2832
commit 8958e820c8
9 changed files with 141 additions and 147 deletions

View File

@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.load.load_model_by_key( tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
**self.clip.tokenizer.model_dump(), text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
context=context,
)
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**self.clip.text_encoder.model_dump(),
context=context,
)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.services.model_manager.load.load_model_by_key( lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
**lora.model_dump(exclude={"weight"}), context=context
)
assert isinstance(lora_info.model, LoRAModelRaw) assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
for trigger in extract_ti_triggers_from_prompt(self.prompt): for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
loaded_model = context.services.model_manager.load.load_model_by_key( loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
**self.clip.text_encoder.model_dump(),
context=context,
).model
assert isinstance(loaded_model, TextualInversionModelRaw) assert isinstance(loaded_model, TextualInversionModelRaw)
ti_list.append((name, loaded_model)) ti_list.append((name, loaded_model))
except UnknownModelException: except UnknownModelException:
@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
lora_prefix: str, lora_prefix: str,
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
tokenizer_info = context.services.model_manager.load.load_model_by_key( tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
**clip_field.tokenizer.model_dump(), text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
context=context,
)
text_encoder_info = context.services.model_manager.load.load_model_by_key(
**clip_field.text_encoder.model_dump(),
context=context,
)
# return zero on empty # return zero on empty
if prompt == "" and zero_on_empty: if prompt == "" and zero_on_empty:
@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.load.load_model_by_key( lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
**lora.model_dump(exclude={"weight"}), context=context
)
lora_model = lora_info.model lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw) assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight) yield (lora_model, lora.weight)
@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
for trigger in extract_ti_triggers_from_prompt(prompt): for trigger in extract_ti_triggers_from_prompt(prompt):
name = trigger[1:-1] name = trigger[1:-1]
try: try:
ti_model = context.services.model_manager.load.load_model_by_attr( ti_model = context.models.load_by_attrs(
model_name=name, model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion,
context=context,
).model ).model
assert isinstance(ti_model, TextualInversionModelRaw) assert isinstance(ti_model, TextualInversionModelRaw)
ti_list.append((name, ti_model)) ti_list.append((name, ti_model))

View File

@ -14,8 +14,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType, ModelType from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
# LS: Consider moving these two classes into model.py # LS: Consider moving these two classes into model.py
@ -90,10 +89,10 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key) ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.services.model_manager.store.search_by_attr( image_encoder_models = context.models.search_by_attrs(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
) )
assert len(image_encoder_models) == 1 assert len(image_encoder_models) == 1

View File

@ -141,7 +141,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None: if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name) image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3: if image_tensor.dim() == 3:
image_tensor = image_tensor.unsqueeze(0) image_tensor = image_tensor.unsqueeze(0)
@ -153,10 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
) )
if image_tensor is not None: if image_tensor is not None:
vae_info = context.services.model_manager.load.load_model_by_key( vae_info = context.models.load(**self.vae.vae.model_dump())
**self.vae.vae.model_dump(),
context=context,
)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0) masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@ -182,10 +179,7 @@ def get_scheduler(
seed: int, seed: int,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_manager.load.load_model_by_key( orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
**scheduler_info.model_dump(),
context=context,
)
with orig_scheduler_info as orig_scheduler: with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config scheduler_config = orig_scheduler.config
@ -399,12 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# and if weight is None, populate with default 1.0? # and if weight is None, populate with default 1.0?
controlnet_data = [] controlnet_data = []
for control_info in control_list: for control_info in control_list:
control_model = exit_stack.enter_context( control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
context.services.model_manager.load.load_model_by_key(
key=control_info.control_model.key,
context=context,
)
)
# control_models.append(control_model) # control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
@ -466,25 +455,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = [] conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter: for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.load.load_model_by_key( context.models.load(key=single_ip_adapter.ip_adapter_model.key)
key=single_ip_adapter.ip_adapter_model.key,
context=context,
)
) )
image_encoder_model_info = context.services.model_manager.load.load_model_by_key( image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
key=single_ip_adapter.image_encoder_model.key,
context=context,
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list): if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields] single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [ single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@ -528,10 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = [] t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter: for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key( t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
key=t2i_adapter_field.t2i_adapter_model.key,
context=context,
)
image = context.images.get_pil(t2i_adapter_field.image.image_name) image = context.images.get_pil(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@ -676,30 +654,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
) )
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# get the unet's config so that we can pass the base to dispatch_progress() # get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key) unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None: def step_callback(state: PipelineIntermediateState) -> None:
self.dispatch_progress(context, source_node_id, state, unet_config.base) context.util.sd_step_callback(state, unet_config.base)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras: for lora in self.unet.loras:
lora_info = context.services.model_manager.load.load_model_by_key( lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
**lora.model_dump(exclude={"weight"}),
context=context,
)
yield (lora_info.model, lora.weight) yield (lora_info.model, lora.weight)
del lora_info del lora_info
return return
unet_info = context.services.model_manager.load.load_model_by_key( unet_info = context.models.load(**self.unet.unet.model_dump())
**self.unet.unet.model_dump(),
context=context,
)
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name) latents = context.tensors.load(self.latents.latents_name)
vae_info = context.services.model_manager.load.load_model_by_key( vae_info = context.models.load(**self.vae.vae.model_dump())
**self.vae.vae.model_dump(),
context=context,
)
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae: with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module) assert isinstance(vae, torch.nn.Module)
@ -1032,10 +997,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
vae_info = context.services.model_manager.load.load_model_by_key( vae_info = context.models.load(**self.vae.vae.model_dump())
**self.vae.vae.model_dump(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3: if image_tensor.dim() == 3:
@ -1239,10 +1201,7 @@ class IdealSizeInvocation(BaseInvocation):
return tuple((x - x % multiple_of) for x in args) return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput: def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.services.model_manager.load.load_model_by_key( unet_config = context.models.get_config(**self.unet.unet.model_dump())
**self.unet.unet.model_dump(),
context=context,
)
aspect = self.width / self.height aspect = self.width / self.height
dimension: float = 512 dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2: if unet_config.base == BaseModelType.StableDiffusion2:

View File

@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
key = self.model.key key = self.model.key
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.store.exists(key): if not context.models.exists(key):
raise Exception(f"Unknown model {key}") raise Exception(f"Unknown model {key}")
return ModelLoaderOutput( return ModelLoaderOutput(
@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
lora_key = self.lora.key lora_key = self.lora.key
if not context.services.model_manager.store.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!") raise Exception(f"Unkown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
lora_key = self.lora.key lora_key = self.lora.key
if not context.services.model_manager.store.exists(lora_key): if not context.models.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!") raise Exception(f"Unknown lora: {lora_key}!")
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras): if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> VAEOutput: def invoke(self, context: InvocationContext) -> VAEOutput:
key = self.vae_model.key key = self.vae_model.key
if not context.services.model_manager.store.exists(key): if not context.models.exists(key):
raise Exception(f"Unkown vae: {key}!") raise Exception(f"Unkown vae: {key}!")
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key))) return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))

View File

@ -43,7 +43,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
model_key = self.model.key model_key = self.model.key
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.store.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
return SDXLModelLoaderOutput( return SDXLModelLoaderOutput(
@ -112,7 +112,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
model_key = self.model.key model_key = self.model.key
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.store.exists(model_key): if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}") raise Exception(f"Unknown model: {model_key}")
return SDXLRefinerModelLoaderOutput( return SDXLRefinerModelLoaderOutput(

View File

@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@ -19,14 +19,14 @@ class ModelLoadServiceBase(ABC):
self, self,
key: str, key: str,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's key, load it and return the LoadedModel object. Given a model's key, load it and return the LoadedModel object.
:param key: Key of model config to be fetched. :param key: Key of model config to be fetched.
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting :param context_data: Invocation context data used for event reporting
""" """
pass pass
@ -35,14 +35,14 @@ class ModelLoadServiceBase(ABC):
self, self,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting :param context_data: Invocation context data used for event reporting
""" """
pass pass
@ -53,7 +53,7 @@ class ModelLoadServiceBase(ABC):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@ -66,7 +66,7 @@ class ModelLoadServiceBase(ABC):
:param base_model: Base model :param base_model: Base model
:param model_type: Type of the model :param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch :param submodel: For main (pipeline models), the submodel to fetch
:param context: The invocation context. :param context_data: The invocation context data.
Exceptions: UnknownModelException -- model with these attributes not known Exceptions: UnknownModelException -- model with these attributes not known
NotImplementedException -- a model loader was not provided at initialization time NotImplementedException -- a model loader was not provided at initialization time

View File

@ -3,10 +3,11 @@
from typing import Optional from typing import Optional
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@ -46,6 +47,9 @@ class ModelLoadService(ModelLoadServiceBase):
), ),
) )
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@property @property
def ram_cache(self) -> ModelCacheBase[AnyModel]: def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader.""" """Return the RAM cache used by this loader."""
@ -60,7 +64,7 @@ class ModelLoadService(ModelLoadServiceBase):
self, self,
key: str, key: str,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's key, load it and return the LoadedModel object. Given a model's key, load it and return the LoadedModel object.
@ -70,7 +74,7 @@ class ModelLoadService(ModelLoadServiceBase):
:param context: Invocation context used for event reporting :param context: Invocation context used for event reporting
""" """
config = self._store.get_model(key) config = self._store.get_model(key)
return self.load_model_by_config(config, submodel_type, context) return self.load_model_by_config(config, submodel_type, context_data)
def load_model_by_attr( def load_model_by_attr(
self, self,
@ -78,7 +82,7 @@ class ModelLoadService(ModelLoadServiceBase):
base_model: BaseModelType, base_model: BaseModelType,
model_type: ModelType, model_type: ModelType,
submodel: Optional[SubModelType] = None, submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@ -109,7 +113,7 @@ class ModelLoadService(ModelLoadServiceBase):
self, self,
model_config: AnyModelConfig, model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None, context_data: Optional[InvocationContextData] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Given a model's configuration, load it and return the LoadedModel object. Given a model's configuration, load it and return the LoadedModel object.
@ -118,15 +122,15 @@ class ModelLoadService(ModelLoadServiceBase):
:param submodel: For main (pipeline models), the submodel to fetch. :param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting :param context: Invocation context used for event reporting
""" """
if context: if context_data:
self._emit_load_event( self._emit_load_event(
context=context, context_data=context_data,
model_config=model_config, model_config=model_config,
) )
loaded_model = self._any_loader.load_model(model_config, submodel_type) loaded_model = self._any_loader.load_model(model_config, submodel_type)
if context: if context_data:
self._emit_load_event( self._emit_load_event(
context=context, context_data=context_data,
model_config=model_config, model_config=model_config,
loaded=True, loaded=True,
) )
@ -134,26 +138,28 @@ class ModelLoadService(ModelLoadServiceBase):
def _emit_load_event( def _emit_load_event(
self, self,
context: InvocationContext, context_data: InvocationContextData,
model_config: AnyModelConfig, model_config: AnyModelConfig,
loaded: Optional[bool] = False, loaded: Optional[bool] = False,
) -> None: ) -> None:
if context.services.queue.is_canceled(context.graph_execution_state_id): if not self._invoker:
return
if self._invoker.services.queue.is_canceled(context_data.session_id):
raise CanceledException() raise CanceledException()
if not loaded: if not loaded:
context.services.events.emit_model_load_started( self._invoker.services.events.emit_model_load_started(
queue_id=context.queue_id, queue_id=context_data.queue_id,
queue_item_id=context.queue_item_id, queue_item_id=context_data.queue_item_id,
queue_batch_id=context.queue_batch_id, queue_batch_id=context_data.batch_id,
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context_data.session_id,
model_config=model_config, model_config=model_config,
) )
else: else:
context.services.events.emit_model_load_completed( self._invoker.services.events.emit_model_load_completed(
queue_id=context.queue_id, queue_id=context_data.queue_id,
queue_item_id=context.queue_item_id, queue_item_id=context_data.queue_item_id,
queue_batch_id=context.queue_batch_id, queue_batch_id=context_data.batch_id,
graph_execution_state_id=context.graph_execution_state_id, graph_execution_state_id=context_data.session_id,
model_config=model_config, model_config=model_config,
) )

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from PIL.Image import Image from PIL.Image import Image
@ -12,8 +13,9 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import LoadedModelInfo from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -259,45 +261,95 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface):
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool: def exists(self, key: str) -> bool:
""" """
Checks if a model exists. Checks if a model exists.
:param model_name: The name of the model to check. :param key: The key of the model.
:param base_model: The base model of the model to check.
:param model_type: The type of the model to check.
""" """
return self._services.model_manager.model_exists(model_name, base_model, model_type) return self._services.model_manager.store.exists(key)
def load( def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModelInfo:
""" """
Loads a model. Loads a model.
:param model_name: The name of the model to get. :param key: The key of the model.
:param base_model: The base model of the model to get. :param submodel_type: The submodel of the model to get.
:param model_type: The type of the model to get.
:param submodel: The submodel of the model to get.
:returns: An object representing the loaded model. :returns: An object representing the loaded model.
""" """
# The model manager emits events as it loads the model. It needs the context data to build # The model manager emits events as it loads the model. It needs the context data to build
# the event payloads. # the event payloads.
return self._services.model_manager.get_model( return self._services.model_manager.load.load_model_by_key(
model_name, base_model, model_type, submodel, context_data=self._context_data key=key, submodel_type=submodel_type, context_data=self._context_data
) )
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def load_by_attrs(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModel:
"""
Loads a model by its attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.load.load_model_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
context_data=self._context_data,
)
def get_config(self, key: str) -> AnyModelConfig:
""" """
Gets a model's info, an dict-like object. Gets a model's info, an dict-like object.
:param model_name: The name of the model to get. :param key: The key of the model.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
""" """
return self._services.model_manager.model_info(model_name, base_model, model_type) return self._services.model_manager.store.get_model(key=key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Gets a model's metadata, if it has any.
:param key: The key of the model.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""
Searches for models by path.
:param path: The path to search for.
"""
return self._services.model_manager.store.search_by_path(path)
def search_by_attrs(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""
Searches for models by attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.store.search_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_format=model_format,
)
class ConfigInterface(InvocationContextInterface): class ConfigInterface(InvocationContextInterface):

View File

@ -4,8 +4,8 @@ import torch
from PIL import Image from PIL import Image
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType
from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL