mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): update invocation context for mm2, update nodes model usage
This commit is contained in:
parent
7a36cd2832
commit
8958e820c8
@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.services.model_manager.load.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
).model
|
||||
loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.services.model_manager.load.load_model_by_key(
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.load.load_model_by_key(
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_model = context.services.model_manager.load.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
ti_model = context.models.load_by_attrs(
|
||||
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
|
@ -14,8 +14,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
@ -90,10 +89,10 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
|
@ -141,7 +141,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
@ -153,10 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
@ -182,10 +179,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@ -399,12 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.load.load_model_by_key(
|
||||
key=control_info.control_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
@ -466,25 +455,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.load.load_model_by_key(
|
||||
key=single_ip_adapter.ip_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
|
||||
key=single_ip_adapter.image_encoder_model.key,
|
||||
context=context,
|
||||
)
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [
|
||||
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
|
||||
]
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@ -528,10 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
|
||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
@ -676,30 +654,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
self.dispatch_progress(context, source_node_id, state, unet_config.base)
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.load.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
@ -1032,10 +997,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.load.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@ -1239,10 +1201,7 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.services.model_manager.load.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
unet_config = context.models.get_config(**self.unet.unet.model_dump())
|
||||
aspect = self.width / self.height
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
|
@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
if not context.models.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
if not context.models.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
@ -43,7 +43,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
@ -112,7 +112,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
|
@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@ -19,14 +19,14 @@ class ModelLoadServiceBase(ABC):
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's key, load it and return the LoadedModel object.
|
||||
|
||||
:param key: Key of model config to be fetched.
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -35,14 +35,14 @@ class ModelLoadServiceBase(ABC):
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -53,7 +53,7 @@ class ModelLoadServiceBase(ABC):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
@ -66,7 +66,7 @@ class ModelLoadServiceBase(ABC):
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
:param context_data: The invocation context data.
|
||||
|
||||
Exceptions: UnknownModelException -- model with these attributes not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
|
@ -3,10 +3,11 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@ -46,6 +47,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
),
|
||||
)
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
@ -60,7 +64,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's key, load it and return the LoadedModel object.
|
||||
@ -70,7 +74,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
config = self._store.get_model(key)
|
||||
return self.load_model_by_config(config, submodel_type, context)
|
||||
return self.load_model_by_config(config, submodel_type, context_data)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
@ -78,7 +82,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
@ -109,7 +113,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
@ -118,15 +122,15 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context:
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
)
|
||||
loaded_model = self._any_loader.load_model(model_config, submodel_type)
|
||||
if context:
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
@ -134,26 +138,28 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
if not self._invoker:
|
||||
return
|
||||
if self._invoker.services.queue.is_canceled(context_data.session_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_id,
|
||||
queue_item_id=context_data.queue_item_id,
|
||||
queue_batch_id=context_data.batch_id,
|
||||
graph_execution_state_id=context_data.session_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
@ -12,8 +13,9 @@ from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
@ -259,45 +261,95 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Checks if a model exists.
|
||||
|
||||
:param model_name: The name of the model to check.
|
||||
:param base_model: The base model of the model to check.
|
||||
:param model_type: The type of the model to check.
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.model_exists(model_name, base_model, model_type)
|
||||
return self._services.model_manager.store.exists(key)
|
||||
|
||||
def load(
|
||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||
) -> LoadedModelInfo:
|
||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Loads a model.
|
||||
|
||||
:param model_name: The name of the model to get.
|
||||
:param base_model: The base model of the model to get.
|
||||
:param model_type: The type of the model to get.
|
||||
:param submodel: The submodel of the model to get.
|
||||
:param key: The key of the model.
|
||||
:param submodel_type: The submodel of the model to get.
|
||||
:returns: An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
return self._services.model_manager.get_model(
|
||||
model_name, base_model, model_type, submodel, context_data=self._context_data
|
||||
return self._services.model_manager.load.load_model_by_key(
|
||||
key=key, submodel_type=submodel_type, context_data=self._context_data
|
||||
)
|
||||
|
||||
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
def load_by_attrs(
|
||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Loads a model by its attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_by_attr(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
context_data=self._context_data,
|
||||
)
|
||||
|
||||
def get_config(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Gets a model's info, an dict-like object.
|
||||
|
||||
:param model_name: The name of the model to get.
|
||||
:param base_model: The base model of the model to get.
|
||||
:param model_type: The type of the model to get.
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.model_info(model_name, base_model, model_type)
|
||||
return self._services.model_manager.store.get_model(key=key)
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
Gets a model's metadata, if it has any.
|
||||
|
||||
:param key: The key of the model.
|
||||
"""
|
||||
return self._services.model_manager.store.get_metadata(key=key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by path.
|
||||
|
||||
:param path: The path to search for.
|
||||
"""
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
|
||||
def search_by_attrs(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""
|
||||
Searches for models by attributes.
|
||||
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
"""
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_format=model_format,
|
||||
)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
|
@ -4,8 +4,8 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user