diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 8bd4a89f45..f59fca6ec4 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -8,15 +8,21 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
 from types import UnionType
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Protocol, Type, TypeVar, Union
 
 import semver
+from PIL.Image import Image as ImageType
 from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined
+import torch
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.model_management.model_manager import ModelInfo
+from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 
 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
@@ -460,7 +466,123 @@ class UIConfigBase(BaseModel):
     )
 
 
+class GetImage(Protocol):
+    def __call__(self, name: str) -> ImageType:
+        ...
+
+
+class SaveImage(Protocol):
+    def __call__(self, image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
+        ...
+
+
+class GetLatents(Protocol):
+    def __call__(self, name: str) -> torch.Tensor:
+        ...
+
+
+class SaveLatents(Protocol):
+    def __call__(self, latents: torch.Tensor) -> str:
+        ...
+
+
+class GetConditioning(Protocol):
+    def __call__(self, name: str) -> torch.Tensor:
+        ...
+
+
+class SaveConditioning(Protocol):
+    def __call__(self, conditioning: torch.Tensor) -> str:
+        ...
+
+
+class IsCanceled(Protocol):
+    def __call__(self) -> bool:
+        ...
+
+
+class EmitDenoisingProgress(Protocol):
+    def __call__(self, progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
+        ...
+
+
+class GetModel(Protocol):
+    def __call__(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+    ) -> ModelInfo:
+        ...
+
+
+class ModelExists(Protocol):
+    def __call__(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+    ) -> bool:
+        ...
+
+
 class InvocationContext:
+    def __init__(
+        self,
+        # context
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
+        source_node_id: str,
+        # methods
+        get_image: GetImage,
+        save_image: SaveImage,
+        get_latents: GetLatents,
+        save_latents: SaveLatents,
+        get_conditioning: GetConditioning,
+        save_conditioning: SaveConditioning,
+        is_canceled: IsCanceled,
+        get_model: GetModel,
+        emit_denoising_progress: EmitDenoisingProgress,
+        model_exists: ModelExists,
+        # services
+        config: InvokeAIAppConfig,
+    ) -> None:
+        # context
+        self.queue_id = queue_id
+        self.queue_item_id = queue_item_id
+        self.queue_batch_id = queue_batch_id
+        self.graph_execution_state_id = graph_execution_state_id
+        self.source_node_id = source_node_id
+
+        # resource methods
+        self.get_image = get_image
+        self.save_image = save_image
+        self.get_latents = get_latents
+        self.save_latents = save_latents
+        self.get_conditioning = get_conditioning
+        self.save_conditioning = save_conditioning
+
+        # execution state
+        self.is_canceled = is_canceled
+
+        # models
+        self.get_model = get_model
+        self.model_exists = model_exists
+
+        # events
+        self.emit_denoising_progress = emit_denoising_progress
+
+        # services
+        self.config = config
+
+        # misc
+        self.categories = ImageCategory
+
+
+class AppInvocationContext:
     """Initialized and provided to on execution of invocations."""
 
     services: InvocationServices
@@ -468,6 +590,7 @@ class InvocationContext:
     queue_id: str
     queue_item_id: int
     queue_batch_id: str
+    source_node_id: str
 
     def __init__(
         self,
@@ -476,12 +599,113 @@
         queue_item_id: int,
         queue_batch_id: str,
         graph_execution_state_id: str,
+        source_node_id: str,
     ):
         self.services = services
         self.graph_execution_state_id = graph_execution_state_id
         self.queue_id = queue_id
         self.queue_item_id = queue_item_id
         self.queue_batch_id = queue_batch_id
+        self.source_node_id = source_node_id
+
+    def get_restricted_context(self, invocation: BaseInvocation) -> InvocationContext:
+        def get_image(name: str) -> ImageType:
+            return self.services.images.get_pil_image(name)
+
+        def save_image(image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
+            metadata = getattr(invocation, "metadata")
+            workflow = getattr(invocation, "workflow")
+
+            image_dto = self.services.images.create(
+                image=image,
+                image_origin=ResourceOrigin.INTERNAL,
+                image_category=category,
+                session_id=self.graph_execution_state_id,
+                node_id=invocation.id,
+                is_intermediate=invocation.is_intermediate,
+                metadata=metadata.model_dump() if metadata else None,
+                workflow=workflow,
+            )
+            return image_dto.image_name
+
+        def get_latents(name: str) -> torch.Tensor:
+            return self.services.latents.get(name)
+
+        def save_latents(latents: torch.Tensor) -> str:
+            name = f"{self.graph_execution_state_id}__{invocation.id}"
+            self.services.latents.save(name=name, data=latents)
+            return name
+
+        def get_conditioning(name: str) -> torch.Tensor:
+            return self.services.latents.get(name)
+
+        def save_conditioning(conditioning: torch.Tensor) -> str:
+            name = f"{self.graph_execution_state_id}__{invocation.id}_conditioning"
+            self.services.latents.save(name=name, data=conditioning)
+            return name
+
+        def is_canceled() -> bool:
+            return self.services.queue.is_canceled(self.graph_execution_state_id)
+
+        def get_model(
+            model_name: str,
+            base_model: BaseModelType,
+            model_type: ModelType,
+            submodel: Optional[SubModelType] = None,
+        ) -> ModelInfo:
+            return self.services.model_manager.get_model(
+                model_name=model_name,
+                base_model=base_model,
+                model_type=model_type,
+                submodel=submodel,
+                queue_id=self.queue_id,
+                queue_item_id=self.queue_item_id,
+                queue_batch_id=self.queue_batch_id,
+                graph_execution_state_id=self.graph_execution_state_id,
+            )
+
+        def model_exists(
+            model_name: str,
+            base_model: BaseModelType,
+            model_type: ModelType,
+        ) -> bool:
+            return self.services.model_manager.model_exists(model_name, base_model, model_type)
+
+        def emit_denoising_progress(progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
+            self.services.events.emit_generator_progress(
+                queue_id=self.queue_id,
+                queue_item_id=self.queue_item_id,
+                queue_batch_id=self.queue_batch_id,
+                graph_execution_state_id=self.graph_execution_state_id,
+                node=invocation.model_dump(),
+                source_node_id=self.source_node_id,
+                progress_image=progress_image,
+                step=step,
+                order=order,
+                total_steps=total_steps,
+            )
+
+        return InvocationContext(
+            # context
+            queue_id=self.queue_id,
+            queue_item_id=self.queue_item_id,
+            queue_batch_id=self.queue_batch_id,
+            graph_execution_state_id=self.graph_execution_state_id,
+            source_node_id=self.source_node_id,
+            # methods
+            get_image=get_image,
+            save_image=save_image,
+            get_latents=get_latents,
+            save_latents=save_latents,
+            get_conditioning=get_conditioning,
+            save_conditioning=save_conditioning,
+            is_canceled=is_canceled,
+            emit_denoising_progress=emit_denoising_progress,
+            get_model=get_model,
+            model_exists=model_exists,
+            # services
+            config=self.services.configuration,
+        )
 
 
 class BaseInvocationOutput(BaseModel):
@@ -613,7 +837,7 @@ class BaseInvocation(ABC, BaseModel):
         """Invoke with provided context and return outputs."""
         pass
 
-    def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
+    def invoke_internal(self, context: AppInvocationContext) -> BaseInvocationOutput:
         for field_name, field in self.model_fields.items():
             if not field.json_schema_extra or callable(field.json_schema_extra):
                 # something has gone terribly awry, we should always have this and it should be a dict
@@ -635,7 +859,7 @@ class BaseInvocation(ABC, BaseModel):
 
         # skip node cache codepath if it's disabled
        if context.services.configuration.node_cache_size == 0:
-            return self.invoke(context)
+            return self.invoke(context.get_restricted_context(invocation=self))
 
         output: BaseInvocationOutput
         if self.use_cache:
@@ -643,7 +867,7 @@ class BaseInvocation(ABC, BaseModel):
             cached_value = context.services.invocation_cache.get(key)
             if cached_value is None:
                 context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
-                output = self.invoke(context)
+                output = self.invoke(context.get_restricted_context(invocation=self))
                 context.services.invocation_cache.save(key, output)
                 return output
             else:
@@ -651,7 +875,7 @@ class BaseInvocation(ABC, BaseModel):
                 return cached_value
         else:
             context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
-            return self.invoke(context)
+            return self.invoke(context.get_restricted_context(invocation=self))
 
     def get_type(self) -> str:
         return self.model_fields["type"].default
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index b3ebc92320..5fca53a5c9 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -66,25 +66,21 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.get_model(
+        tokenizer_info = context.get_model(
             **self.clip.tokenizer.model_dump(),
-            context=context,
         )
-        text_encoder_info = context.services.model_manager.get_model(
+        text_encoder_info = context.get_model(
             **self.clip.text_encoder.model_dump(),
-            context=context,
         )
 
         def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.get_model(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = []
         for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@@ -93,11 +89,10 @@ class CompelInvocation(BaseInvocation):
                 ti_list.append(
                     (
                         name,
-                        context.services.model_manager.get_model(
+                        context.get_model(
                             model_name=name,
                             base_model=self.clip.text_encoder.base_model,
                             model_type=ModelType.TextualInversion,
-                            context=context,
                         ).context.model,
                     )
                 )
@@ -126,7 +121,7 @@ class CompelInvocation(BaseInvocation):
 
         conjunction = Compel.parse_prompt_string(self.prompt)
 
-        if context.services.configuration.log_tokenization:
+        if context.config.log_tokenization:
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
         c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -147,8 +142,7 @@ class CompelInvocation(BaseInvocation):
                 ]
             )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.save_conditioning(conditioning_data)
 
         return ConditioningOutput(
             conditioning=ConditioningField(
diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
index 3a4f4eadac..c773c5b600 100644
--- a/invokeai/app/invocations/image.py
+++ b/invokeai/app/invocations/image.py
@@ -397,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation):
     )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+        image = context.get_image(self.image.image_name)
 
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
 
@@ -406,21 +406,12 @@
             resample=resample_mode,
         )
 
-        image_dto = context.services.images.create(
-            image=resize_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata.model_dump() if self.metadata else None,
-            workflow=self.workflow,
-        )
+        image_name = context.save_image(image=resize_image)
 
         return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
+            image=ImageField(image_name=image_name),
+            width=resize_image.width,
+            height=resize_image.height,
         )
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 7ce0ae7a8a..933427104b 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -182,9 +182,8 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_manager.get_model(
+    orig_scheduler_info = context.get_model(
         **scheduler_info.model_dump(),
-        context=context,
     )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
@@ -298,15 +297,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def dispatch_progress(
         self,
         context: InvocationContext,
-        source_node_id: str,
         intermediate_state: PipelineIntermediateState,
         base_model: BaseModelType,
     ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
-            node=self.model_dump(),
-            source_node_id=source_node_id,
             base_model=base_model,
         )
 
@@ -317,11 +313,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        positive_cond_data = context.get_conditioning(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = c.extra_conditioning
 
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
+        negative_cond_data = context.get_conditioning(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
@@ -408,17 +404,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         controlnet_data = []
         for control_info in control_list:
             control_model = exit_stack.enter_context(
-                context.services.model_manager.get_model(
+                context.get_model(
                     model_name=control_info.control_model.model_name,
                     model_type=ModelType.ControlNet,
                     base_model=control_info.control_model.base_model,
-                    context=context,
                 )
             )
 
             # control_models.append(control_model)
             control_image_field = control_info.image
-            input_image = context.services.images.get_pil_image(control_image_field.image_name)
+            input_image = context.get_image(control_image_field.image_name)
             # self.image.image_type, self.image.image_name
             # FIXME: still need to test with different widths, heights, devices, dtypes
             #       and add in batch_size, num_images_per_prompt?
@@ -476,22 +471,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
             conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_manager.get_model(
+                context.get_model(
                     model_name=single_ip_adapter.ip_adapter_model.model_name,
                     model_type=ModelType.IPAdapter,
                     base_model=single_ip_adapter.ip_adapter_model.base_model,
-                    context=context,
                 )
             )
 
-            image_encoder_model_info = context.services.model_manager.get_model(
+            image_encoder_model_info = context.get_model(
                 model_name=single_ip_adapter.image_encoder_model.model_name,
                 model_type=ModelType.CLIPVision,
                 base_model=single_ip_adapter.image_encoder_model.base_model,
-                context=context,
             )
 
-            input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
+            input_image = context.get_image(single_ip_adapter.image.image_name)
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -535,13 +528,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_manager.get_model(
+            t2i_adapter_model_info = context.get_model(
                 model_name=t2i_adapter_field.t2i_adapter_model.model_name,
                 model_type=ModelType.T2IAdapter,
                 base_model=t2i_adapter_field.t2i_adapter_model.base_model,
-                context=context,
             )
-            image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
+            image = context.get_image(t2i_adapter_field.image.image_name)
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
             if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
@@ -651,11 +643,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
         seed = None
         noise = None
         if self.noise is not None:
-            noise = context.services.latents.get(self.noise.latents_name)
+            noise = context.get_latents(self.noise.latents_name)
             seed = self.noise.seed
 
         if self.latents is not None:
-            latents = context.services.latents.get(self.latents.latents_name)
+            latents = context.get_latents(self.latents.latents_name)
             if seed is None:
                 seed = self.latents.seed
 
@@ -681,26 +673,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
             do_classifier_free_guidance=True,
         )
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
         def step_callback(state: PipelineIntermediateState):
-            self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
+            self.dispatch_progress(context, state, self.unet.unet.base_model)
 
         def _lora_loader():
             for lora in self.unet.loras:
-                lora_info = context.services.model_manager.get_model(
+                lora_info = context.get_model(
                     **lora.model_dump(exclude={"weight"}),
-                    context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        unet_info = context.services.model_manager.get_model(
+        unet_info = context.get_model(
             **self.unet.unet.model_dump(),
-            context=context,
         )
         with (
             ExitStack() as exit_stack,
@@ -775,9 +761,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
 
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, result_latents)
-        return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
+        latents_name = context.save_latents(result_latents)
+        return build_latents_output(latents_name=latents_name, latents=result_latents, seed=seed)
 
 
 @invocation(
@@ -808,11 +793,10 @@ class LatentsToImageInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        latents = context.services.latents.get(self.latents.latents_name)
+        latents = context.get_latents(self.latents.latents_name)
 
-        vae_info = context.services.model_manager.get_model(
+        vae_info = context.get_model(
             **self.vae.vae.model_dump(),
-            context=context,
         )
 
         with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
@@ -842,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 latents = latents.half()
 
-            if self.tiled or context.services.configuration.tiled_decode:
+            if self.tiled or context.config.tiled_decode:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -866,21 +850,12 @@ class LatentsToImageInvocation(BaseInvocation):
            if choose_torch_device() == torch.device("mps"):
                 mps.empty_cache()
 
-            image_dto = context.services.images.create(
-                image=image,
-                image_origin=ResourceOrigin.INTERNAL,
-                image_category=ImageCategory.GENERAL,
-                node_id=self.id,
-                session_id=context.graph_execution_state_id,
-                is_intermediate=self.is_intermediate,
-                metadata=self.metadata.model_dump() if self.metadata else None,
-                workflow=self.workflow,
-            )
+            image_name = context.save_image(image, category=context.categories.GENERAL)
 
             return ImageOutput(
-                image=ImageField(image_name=image_dto.image_name),
-                width=image_dto.width,
-                height=image_dto.height,
+                image=ImageField(image_name=image_name),
+                width=image.width,
+                height=image.height,
             )
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index dfa1075d6e..d51bd70793 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -98,7 +98,7 @@ class MainModelLoaderInvocation(BaseInvocation):
         model_type = ModelType.Main
 
         # TODO: not found exceptions
-        if not context.services.model_manager.model_exists(
+        if not context.model_exists(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index 3c1651a2f0..3871d5d6a0 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -124,6 +124,5 @@ class NoiseInvocation(BaseInvocation):
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, noise)
-        return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
+        latents_name = context.save_latents(noise)
+        return build_noise_output(latents_name=latents_name, latents=noise, seed=self.seed)
diff --git a/invokeai/app/invocations/shared.py b/invokeai/app/invocations/shared.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py
index c59fb678ef..dd81f6cde2 100644
--- a/invokeai/app/services/invocation_processor/invocation_processor_default.py
+++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py
@@ -4,7 +4,7 @@ from threading import BoundedSemaphore, Event, Thread
 from typing import Optional
 
 import invokeai.backend.util.logging as logger
-from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.app.invocations.baseinvocation import AppInvocationContext
 from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
 
 from ..invoker import Invoker
@@ -96,18 +96,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 # Invoke
                 try:
                     graph_id = graph_execution_state.id
+                    source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
+
                     with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
                         # use the internal invoke_internal(), which wraps the node's invoke() method,
                         # which handles a few things:
                         # - nodes that require a value, but get it only from a connection
                         # - referencing the invocation cache instead of executing the node
                         outputs = invocation.invoke_internal(
-                            InvocationContext(
+                            AppInvocationContext(
                                 services=self.__invoker.services,
                                 graph_execution_state_id=graph_execution_state.id,
                                 queue_item_id=queue_item.session_queue_item_id,
                                 queue_id=queue_item.session_queue_id,
                                 queue_batch_id=queue_item.session_queue_batch_id,
+                                source_node_id=source_node_id,
                             )
                         )
diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py
index 4c2fc4c085..6fd749222e 100644
--- a/invokeai/app/services/model_manager/model_manager_base.py
+++ b/invokeai/app/services/model_manager/model_manager_base.py
@@ -48,9 +48,12 @@ class ModelManagerServiceBase(ABC):
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
         submodel: Optional[SubModelType] = None,
         node: Optional[BaseInvocation] = None,
-        context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """Retrieve the indicated model with name and type.
         submodel can be used to get a part (such as the vae)
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index cdb3e59a91..f9285b53a6 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -11,6 +11,7 @@ from pydantic import Field
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_management import (
     AddModelResult,
     BaseModelType,
@@ -86,28 +87,35 @@ class ModelManagerService(ModelManagerServiceBase):
         )
         logger.info("Model manager service initialized")
 
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+
     def get_model(
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
     ) -> ModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
        part (such as the vae) of a diffusers mode.
         """
-        # we can emit model loading events if we are executing with access to the invocation context
-        if context:
-            self._emit_load_event(
-                context=context,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
-                submodel=submodel,
-            )
+        self._emit_load_event(
+            queue_id=queue_id,
+            queue_item_id=queue_item_id,
+            queue_batch_id=queue_batch_id,
+            graph_execution_state_id=graph_execution_state_id,
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+        )
 
         model_info = self.mgr.get_model(
             model_name,
@@ -116,15 +124,17 @@ class ModelManagerService(ModelManagerServiceBase):
             submodel,
         )
 
-        if context:
-            self._emit_load_event(
-                context=context,
-                model_name=model_name,
-                base_model=base_model,
-                model_type=model_type,
-                submodel=submodel,
-                model_info=model_info,
-            )
+        self._emit_load_event(
+            queue_id=queue_id,
+            queue_item_id=queue_item_id,
+            queue_batch_id=queue_batch_id,
+            graph_execution_state_id=graph_execution_state_id,
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+            model_info=model_info,
+        )
 
         return model_info
@@ -263,22 +273,25 @@ class ModelManagerService(ModelManagerServiceBase):
 
     def _emit_load_event(
         self,
-        context: InvocationContext,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
+        queue_id: str,
+        queue_item_id: int,
+        queue_batch_id: str,
+        graph_execution_state_id: str,
         submodel: Optional[SubModelType] = None,
         model_info: Optional[ModelInfo] = None,
     ):
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
+        if self._invoker.services.queue.is_canceled(graph_execution_state_id):
            raise CanceledException()
 
         if model_info:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_completed(
+                queue_id=queue_id,
+                queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
+                graph_execution_state_id=graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -286,11 +299,11 @@
                 model_info=model_info,
             )
         else:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_started(
+                queue_id=queue_id,
+                queue_item_id=queue_item_id,
+                queue_batch_id=queue_batch_id,
+                graph_execution_state_id=graph_execution_state_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index f166206d52..c5869124da 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -27,11 +27,9 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
 def stable_diffusion_step_callback(
     context: InvocationContext,
     intermediate_state: PipelineIntermediateState,
-    node: dict,
-    source_node_id: str,
     base_model: BaseModelType,
 ):
-    if context.services.queue.is_canceled(context.graph_execution_state_id):
+    if context.is_canceled():
         raise CanceledException
 
     # Some schedulers report not only the noisy latents at the current timestep,
@@ -108,13 +106,7 @@ def stable_diffusion_step_callback(
 
     dataURL = image_to_dataURL(image, image_format="JPEG")
 
-    context.services.events.emit_generator_progress(
-        queue_id=context.queue_id,
-        queue_item_id=context.queue_item_id,
-        queue_batch_id=context.queue_batch_id,
-        graph_execution_state_id=context.graph_execution_state_id,
-        node=node,
-        source_node_id=source_node_id,
+    context.emit_denoising_progress(
         progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
         step=intermediate_state.step,
        order=intermediate_state.order,
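For orientation, a node written against the restricted `InvocationContext` introduced by this diff might look like the sketch below. Only the `context.get_image` / `context.save_image` calls, the `context.categories` alias, and the pattern of reading width/height from the PIL image (since `save_image` returns just a name) come from the changes above (compare `ImageResizeInvocation` and `LatentsToImageInvocation`); the `ImageInvertInvocation` node itself is hypothetical, and the import paths assume the existing 3.x module layout used elsewhere in this diff.

```python
from PIL import ImageOps

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    InputField,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput


@invocation("img_invert_example", title="Invert Image (example)", category="image", version="1.0.0")
class ImageInvertInvocation(BaseInvocation):
    """Hypothetical node illustrating the restricted InvocationContext API."""

    image: ImageField = InputField(description="The image to invert")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Resources are fetched through the context methods, not context.services.*
        image = context.get_image(self.image.image_name)

        inverted = ImageOps.invert(image.convert("RGB"))

        # save_image handles origin/session/node bookkeeping internally and returns
        # only a name, so width/height are read from the PIL image itself.
        image_name = context.save_image(inverted, category=context.categories.GENERAL)

        return ImageOutput(
            image=ImageField(image_name=image_name),
            width=inverted.width,
            height=inverted.height,
        )
```

The design choice the diff makes is visible here: nodes depend only on a narrow, Protocol-typed surface (`GetImage`, `SaveImage`, `GetModel`, ...) that `AppInvocationContext.get_restricted_context` wires up from `InvocationServices`, so node code no longer reaches into the full service registry or threads `context=context` through model-manager calls.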