feat(nodes): WIP restricted invocation context

psychedelicious
2023-10-16 20:40:07 +11:00
parent 0aedd6d9f0
commit 09609cd553
11 changed files with 323 additions and 129 deletions

View File

@@ -8,15 +8,21 @@ from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Protocol, Type, TypeVar, Union
import semver
from PIL.Image import Image as ImageType
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
import torch
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@@ -460,7 +466,123 @@ class UIConfigBase(BaseModel):
)
class GetImage(Protocol):
def __call__(self, name: str) -> ImageType:
...
class SaveImage(Protocol):
def __call__(self, image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
...
class GetLatents(Protocol):
def __call__(self, name: str) -> torch.Tensor:
...
class SaveLatents(Protocol):
def __call__(self, latents: torch.Tensor) -> str:
...
class GetConditioning(Protocol):
def __call__(self, name: str) -> torch.Tensor:
...
class SaveConditioning(Protocol):
def __call__(self, conditioning: torch.Tensor) -> str:
...
class IsCanceled(Protocol):
def __call__(self) -> bool:
...
class EmitDenoisingProgress(Protocol):
def __call__(self, progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
...
class GetModel(Protocol):
def __call__(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
) -> ModelInfo:
...
class ModelExists(Protocol):
def __call__(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
...
class InvocationContext:
def __init__(
self,
# context
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
source_node_id: str,
# methods
get_image: GetImage,
save_image: SaveImage,
get_latents: GetLatents,
save_latents: SaveLatents,
get_conditioning: GetConditioning,
save_conditioning: SaveConditioning,
is_canceled: IsCanceled,
get_model: GetModel,
emit_denoising_progress: EmitDenoisingProgress,
model_exists: ModelExists,
# services
config: InvokeAIAppConfig,
) -> None:
# context
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
self.graph_execution_state_id = graph_execution_state_id
self.source_node_id = source_node_id
# resource methods
self.get_image = get_image
self.save_image = save_image
self.get_latents = get_latents
self.save_latents = save_latents
self.get_conditioning = get_conditioning
self.save_conditioning = save_conditioning
# execution state
self.is_canceled = is_canceled
# models
self.get_model = get_model
self.model_exists = model_exists
# events
self.emit_denoising_progress = emit_denoising_progress
# services
self.config = config
# misc
self.categories = ImageCategory
class AppInvocationContext:
"""Initialized and provided to on execution of invocations."""
services: InvocationServices
@@ -468,6 +590,7 @@ class InvocationContext:
queue_id: str
queue_item_id: int
queue_batch_id: str
source_node_id: str
def __init__(
self,
@@ -476,12 +599,113 @@ class InvocationContext:
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
source_node_id: str,
):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
self.queue_batch_id = queue_batch_id
self.source_node_id = source_node_id
def get_restricted_context(self, invocation: BaseInvocation) -> InvocationContext:
def get_image(name: str) -> ImageType:
return self.services.images.get_pil_image(name)
def save_image(image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
metadata = getattr(invocation, "metadata")
workflow = getattr(invocation, "workflow")
image_dto = self.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=category,
session_id=self.graph_execution_state_id,
node_id=invocation.id,
is_intermediate=invocation.is_intermediate,
metadata=metadata.model_dump() if metadata else None,
workflow=workflow,
)
return image_dto.image_name
def get_latents(name: str) -> torch.Tensor:
return self.services.latents.get(name)
def save_latents(latents: torch.Tensor) -> str:
name = f"{self.graph_execution_state_id}__{invocation.id}"
self.services.latents.save(name=name, data=latents)
return name
def get_conditioning(name: str) -> torch.Tensor:
return self.services.latents.get(name)
def save_conditioning(conditioning: torch.Tensor) -> str:
name = f"{self.graph_execution_state_id}__{invocation.id}_conditioning"
self.services.latents.save(name=name, data=conditioning)
return name
def is_canceled() -> bool:
return self.services.queue.is_canceled(self.graph_execution_state_id)
def get_model(
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: Optional[SubModelType] = None,
) -> ModelInfo:
return self.services.model_manager.get_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
queue_id=self.queue_id,
queue_item_id=self.queue_item_id,
queue_batch_id=self.queue_batch_id,
graph_execution_state_id=self.graph_execution_state_id,
)
def model_exists(
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
return self.services.model_manager.model_exists(model_name, base_model, model_type)
def emit_denoising_progress(progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
self.services.events.emit_generator_progress(
queue_id=self.queue_id,
queue_item_id=self.queue_item_id,
queue_batch_id=self.queue_batch_id,
graph_execution_state_id=self.graph_execution_state_id,
node=invocation.model_dump(),
source_node_id=self.source_node_id,
progress_image=progress_image,
step=step,
order=order,
total_steps=total_steps,
)
return InvocationContext(
# context
queue_id=self.queue_id,
queue_item_id=self.queue_item_id,
queue_batch_id=self.queue_batch_id,
graph_execution_state_id=self.graph_execution_state_id,
source_node_id=self.source_node_id,
# methods
get_image=get_image,
save_image=save_image,
get_latents=get_latents,
save_latents=save_latents,
get_conditioning=get_conditioning,
save_conditioning=save_conditioning,
is_canceled=is_canceled,
emit_denoising_progress=emit_denoising_progress,
get_model=get_model,
model_exists=model_exists,
# services
config=self.services.configuration,
)
class BaseInvocationOutput(BaseModel):
@@ -613,7 +837,7 @@ class BaseInvocation(ABC, BaseModel):
"""Invoke with provided context and return outputs."""
pass
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
def invoke_internal(self, context: AppInvocationContext) -> BaseInvocationOutput:
for field_name, field in self.model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict
@@ -635,7 +859,7 @@ class BaseInvocation(ABC, BaseModel):
# skip node cache codepath if it's disabled
if context.services.configuration.node_cache_size == 0:
return self.invoke(context)
return self.invoke(context.get_restricted_context(invocation=self))
output: BaseInvocationOutput
if self.use_cache:
@@ -643,7 +867,7 @@ class BaseInvocation(ABC, BaseModel):
cached_value = context.services.invocation_cache.get(key)
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
output = self.invoke(context.get_restricted_context(invocation=self))
context.services.invocation_cache.save(key, output)
return output
else:
@@ -651,7 +875,7 @@ class BaseInvocation(ABC, BaseModel):
return cached_value
else:
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
return self.invoke(context.get_restricted_context(invocation=self))
def get_type(self) -> str:
return self.model_fields["type"].default
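
For orientation, here is a minimal sketch of how a node consumes the restricted context, mirroring the get_image/save_image changes to ImageResizeInvocation further down in this diff. The example node, its fields, and the import paths are assumptions for illustration and are not part of this commit.

from PIL import ImageOps

from invokeai.app.invocations.baseinvocation import (  # import paths assumed
    BaseInvocation,
    InputField,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.primitives import ImageField, ImageOutput


@invocation("flip_image_example", title="Flip Image (example)", category="image", version="1.0.0")
class FlipImageExampleInvocation(BaseInvocation):
    """Hypothetical node: flips an image horizontally using only the restricted context."""

    image: ImageField = InputField(description="The image to flip")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Resources are read and written through the narrow protocol methods,
        # never through context.services.
        image = context.get_image(self.image.image_name)
        flipped = ImageOps.mirror(image)
        image_name = context.save_image(flipped, category=context.categories.GENERAL)
        return ImageOutput(
            image=ImageField(image_name=image_name),
            width=flipped.width,
            height=flipped.height,
        )

Because invoke() only sees the protocol callables, the queue identifiers and service registry stay encapsulated in AppInvocationContext.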

View File

@@ -66,25 +66,21 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
tokenizer_info = context.get_model(
**self.clip.tokenizer.model_dump(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
text_encoder_info = context.get_model(
**self.clip.text_encoder.model_dump(),
context=context,
)
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.model_dump(exclude={"weight"}), context=context
)
lora_info = context.get_model(**lora.model_dump(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
# loras = [(context.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@@ -93,11 +89,10 @@ class CompelInvocation(BaseInvocation):
ti_list.append(
(
name,
context.services.model_manager.get_model(
context.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model,
)
)
@@ -126,7 +121,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.services.configuration.log_tokenization:
if context.config.log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -147,8 +142,7 @@ class CompelInvocation(BaseInvocation):
]
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
conditioning_name = context.save_conditioning(conditioning_data)
return ConditioningOutput(
conditioning=ConditioningField(

View File

@@ -397,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image = context.get_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -406,21 +406,12 @@ class ImageResizeInvocation(BaseInvocation):
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)
image_name = context.save_image(image=resize_image)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
image=ImageField(image_name=image_name),
width=resize_image.width,
height=resize_image.height,
)

View File

@@ -182,9 +182,8 @@ def get_scheduler(
seed: int,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_manager.get_model(
orig_scheduler_info = context.get_model(
**scheduler_info.model_dump(),
context=context,
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -298,15 +297,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.model_dump(),
source_node_id=source_node_id,
base_model=base_model,
)
@@ -317,11 +313,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
unet,
seed,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
positive_cond_data = context.get_conditioning(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = c.extra_conditioning
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
negative_cond_data = context.get_conditioning(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = ConditioningData(
@@ -408,17 +404,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
context.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
input_image = context.get_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@@ -476,22 +471,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
context.get_model(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
image_encoder_model_info = context.services.model_manager.get_model(
image_encoder_model_info = context.get_model(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
context=context,
)
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
input_image = context.get_image(single_ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -535,13 +528,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.get_model(
t2i_adapter_model_info = context.get_model(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
context=context,
)
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
image = context.get_image(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
@@ -651,11 +643,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed = None
noise = None
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
noise = context.get_latents(self.noise.latents_name)
seed = self.noise.seed
if self.latents is not None:
latents = context.services.latents.get(self.latents.latents_name)
latents = context.get_latents(self.latents.latents_name)
if seed is None:
seed = self.latents.seed
@@ -681,26 +673,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
self.dispatch_progress(context, state, self.unet.unet.base_model)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
lora_info = context.get_model(
**lora.model_dump(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
unet_info = context.get_model(
**self.unet.unet.model_dump(),
context=context,
)
with (
ExitStack() as exit_stack,
@@ -775,9 +761,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
latents_name = context.save_latents(result_latents)
return build_latents_output(latents_name=latents_name, latents=result_latents, seed=seed)
@invocation(
@@ -808,11 +793,10 @@ class LatentsToImageInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
latents = context.get_latents(self.latents.latents_name)
vae_info = context.services.model_manager.get_model(
vae_info = context.get_model(
**self.vae.vae.model_dump(),
context=context,
)
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
@@ -842,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation):
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.services.configuration.tiled_decode:
if self.tiled or context.config.tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
@@ -866,21 +850,12 @@ class LatentsToImageInvocation(BaseInvocation):
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.model_dump() if self.metadata else None,
workflow=self.workflow,
)
image_name = context.save_image(image, category=context.categories.GENERAL)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
image=ImageField(image_name=image_name),
width=image.width,
height=image.height,
)

View File

@@ -98,7 +98,7 @@ class MainModelLoaderInvocation(BaseInvocation):
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
if not context.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,

View File

@@ -124,6 +124,5 @@ class NoiseInvocation(BaseInvocation):
seed=self.seed,
use_cpu=self.use_cpu,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
latents_name = context.save_latents(noise)
return build_noise_output(latents_name=latents_name, latents=noise, seed=self.seed)

View File

View File

@@ -4,7 +4,7 @@ from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.app.invocations.baseinvocation import AppInvocationContext
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
from ..invoker import Invoker
@@ -96,18 +96,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
graph_id = graph_execution_state.id
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
outputs = invocation.invoke_internal(
InvocationContext(
AppInvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
source_node_id=source_node_id,
)
)

View File

@@ -48,9 +48,12 @@ class ModelManagerServiceBase(ABC):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
submodel: Optional[SubModelType] = None,
node: Optional[BaseInvocation] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""Retrieve the indicated model with name and type.
submodel can be used to get a part (such as the vae)

View File

@@ -11,6 +11,7 @@ from pydantic import Field
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_management import (
AddModelResult,
BaseModelType,
@@ -86,28 +87,35 @@ class ModelManagerService(ModelManagerServiceBase):
)
logger.info("Model manager service initialized")
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def get_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
submodel: Optional[SubModelType] = None,
context: Optional[InvocationContext] = None,
) -> ModelInfo:
"""
Retrieve the indicated model. submodel can be used to get a
part (such as the vae) of a diffusers model.
"""
# we can emit model loading events if we are executing with access to the invocation context
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
self._emit_load_event(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
)
model_info = self.mgr.get_model(
model_name,
@@ -116,15 +124,17 @@ class ModelManagerService(ModelManagerServiceBase):
submodel,
)
if context:
self._emit_load_event(
context=context,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
)
self._emit_load_event(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
)
return model_info
@@ -263,22 +273,25 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
context: InvocationContext,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
submodel: Optional[SubModelType] = None,
model_info: Optional[ModelInfo] = None,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
if self._invoker.services.queue.is_canceled(graph_execution_state_id):
raise CanceledException()
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
self._invoker.services.events.emit_model_load_completed(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
@@ -286,11 +299,11 @@ class ModelManagerService(ModelManagerServiceBase):
model_info=model_info,
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
self._invoker.services.events.emit_model_load_started(
queue_id=queue_id,
queue_item_id=queue_item_id,
queue_batch_id=queue_batch_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
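
To summarize the caller-side effect of this file's changes: ModelManagerService.get_model no longer accepts an InvocationContext; callers pass the queue and session identifiers explicitly, and the service emits load events through the Invoker captured in start(). Below is a hedged sketch of a direct call under the new signature; the helper function, model name, and identifiers are placeholders, not code from this commit.

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType


def load_main_model(
    services: InvocationServices,
    queue_id: str,
    queue_item_id: int,
    queue_batch_id: str,
    graph_execution_state_id: str,
) -> ModelInfo:
    # Queue/session identifiers are now passed explicitly; the model manager
    # reaches the queue and event bus via the Invoker it received in start(),
    # not via an InvocationContext.
    return services.model_manager.get_model(
        model_name="stable-diffusion-v1-5",  # placeholder model name
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.Main,
        queue_id=queue_id,
        queue_item_id=queue_item_id,
        queue_batch_id=queue_batch_id,
        graph_execution_state_id=graph_execution_state_id,
    )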

View File

@@ -27,11 +27,9 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
def stable_diffusion_step_callback(
context: InvocationContext,
intermediate_state: PipelineIntermediateState,
node: dict,
source_node_id: str,
base_model: BaseModelType,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
if context.is_canceled():
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
@@ -108,13 +106,7 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
queue_batch_id=context.queue_batch_id,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
context.emit_denoising_progress(
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,