mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): WIP restricted invocation context
This commit is contained in:
@ -8,15 +8,21 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Protocol, Type, TypeVar, Union
|
||||
|
||||
import semver
|
||||
from PIL.Image import Image as ImageType
|
||||
from pydantic import BaseModel, ConfigDict, Field, create_model, field_validator
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@ -460,7 +466,123 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class GetImage(Protocol):
|
||||
def __call__(self, name: str) -> ImageType:
|
||||
...
|
||||
|
||||
|
||||
class SaveImage(Protocol):
|
||||
def __call__(self, image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
|
||||
...
|
||||
|
||||
|
||||
class GetLatents(Protocol):
|
||||
def __call__(self, name: str) -> torch.Tensor:
|
||||
...
|
||||
|
||||
|
||||
class SaveLatents(Protocol):
|
||||
def __call__(self, latents: torch.Tensor) -> str:
|
||||
...
|
||||
|
||||
|
||||
class GetConditioning(Protocol):
|
||||
def __call__(self, name: str) -> torch.Tensor:
|
||||
...
|
||||
|
||||
|
||||
class SaveConditioning(Protocol):
|
||||
def __call__(self, conditioning: torch.Tensor) -> str:
|
||||
...
|
||||
|
||||
|
||||
class IsCanceled(Protocol):
|
||||
def __call__(self) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class EmitDenoisingProgress(Protocol):
|
||||
def __call__(self, progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
|
||||
...
|
||||
|
||||
|
||||
class GetModel(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
) -> ModelInfo:
|
||||
...
|
||||
|
||||
|
||||
class ModelExists(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
def __init__(
|
||||
self,
|
||||
# context
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
source_node_id: str,
|
||||
# methods
|
||||
get_image: GetImage,
|
||||
save_image: SaveImage,
|
||||
get_latents: GetLatents,
|
||||
save_latents: SaveLatents,
|
||||
get_conditioning: GetConditioning,
|
||||
save_conditioning: SaveConditioning,
|
||||
is_canceled: IsCanceled,
|
||||
get_model: GetModel,
|
||||
emit_denoising_progress: EmitDenoisingProgress,
|
||||
model_exists: ModelExists,
|
||||
# services
|
||||
config: InvokeAIAppConfig,
|
||||
) -> None:
|
||||
# context
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.source_node_id = source_node_id
|
||||
|
||||
# resource methods
|
||||
self.get_image = get_image
|
||||
self.save_image = save_image
|
||||
self.get_latents = get_latents
|
||||
self.save_latents = save_latents
|
||||
self.get_conditioning = get_conditioning
|
||||
self.save_conditioning = save_conditioning
|
||||
|
||||
# execution state
|
||||
self.is_canceled = is_canceled
|
||||
|
||||
# models
|
||||
self.get_model = get_model
|
||||
self.model_exists = model_exists
|
||||
|
||||
# events
|
||||
self.emit_denoising_progress = emit_denoising_progress
|
||||
|
||||
# services
|
||||
self.config = config
|
||||
|
||||
# misc
|
||||
self.categories = ImageCategory
|
||||
|
||||
|
||||
class AppInvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
@ -468,6 +590,7 @@ class InvocationContext:
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
source_node_id: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -476,12 +599,113 @@ class InvocationContext:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
source_node_id: str,
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.source_node_id = source_node_id
|
||||
|
||||
def get_restricted_context(self, invocation: BaseInvocation) -> InvocationContext:
|
||||
def get_image(name: str) -> ImageType:
|
||||
return self.services.images.get_pil_image(name)
|
||||
|
||||
def save_image(image: ImageType, category: ImageCategory = ImageCategory.GENERAL) -> str:
|
||||
metadata = getattr(invocation, "metadata")
|
||||
workflow = getattr(invocation, "workflow")
|
||||
|
||||
image_dto = self.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=category,
|
||||
session_id=self.graph_execution_state_id,
|
||||
node_id=invocation.id,
|
||||
is_intermediate=invocation.is_intermediate,
|
||||
metadata=metadata.model_dump() if metadata else None,
|
||||
workflow=workflow,
|
||||
)
|
||||
return image_dto.image_name
|
||||
|
||||
def get_latents(name: str) -> torch.Tensor:
|
||||
return self.services.latents.get(name)
|
||||
|
||||
def save_latents(latents: torch.Tensor) -> str:
|
||||
name = f"{self.graph_execution_state_id}__{invocation.id}"
|
||||
self.services.latents.save(name=name, data=latents)
|
||||
return name
|
||||
|
||||
def get_conditioning(name: str) -> torch.Tensor:
|
||||
return self.services.latents.get(name)
|
||||
|
||||
def save_conditioning(conditioning: torch.Tensor) -> str:
|
||||
name = f"{self.graph_execution_state_id}__{invocation.id}_conditioning"
|
||||
self.services.latents.save(name=name, data=conditioning)
|
||||
return name
|
||||
|
||||
def is_canceled() -> bool:
|
||||
return self.services.queue.is_canceled(self.graph_execution_state_id)
|
||||
|
||||
def get_model(
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
) -> ModelInfo:
|
||||
return self.services.model_manager.get_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
queue_id=self.queue_id,
|
||||
queue_item_id=self.queue_item_id,
|
||||
queue_batch_id=self.queue_batch_id,
|
||||
graph_execution_state_id=self.graph_execution_state_id,
|
||||
)
|
||||
|
||||
def model_exists(
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
return self.services.model_manager.model_exists(model_name, base_model, model_type)
|
||||
|
||||
def emit_denoising_progress(progress_image: ProgressImage, step: int, order: int, total_steps: int) -> None:
|
||||
self.services.events.emit_generator_progress(
|
||||
queue_id=self.queue_id,
|
||||
queue_item_id=self.queue_item_id,
|
||||
queue_batch_id=self.queue_batch_id,
|
||||
graph_execution_state_id=self.graph_execution_state_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=self.source_node_id,
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
order=order,
|
||||
total_steps=total_steps,
|
||||
)
|
||||
|
||||
return InvocationContext(
|
||||
# context
|
||||
queue_id=self.queue_id,
|
||||
queue_item_id=self.queue_item_id,
|
||||
queue_batch_id=self.queue_batch_id,
|
||||
graph_execution_state_id=self.graph_execution_state_id,
|
||||
source_node_id=self.source_node_id,
|
||||
# methods
|
||||
get_image=get_image,
|
||||
save_image=save_image,
|
||||
get_latents=get_latents,
|
||||
save_latents=save_latents,
|
||||
get_conditioning=get_conditioning,
|
||||
save_conditioning=save_conditioning,
|
||||
is_canceled=is_canceled,
|
||||
emit_denoising_progress=emit_denoising_progress,
|
||||
get_model=get_model,
|
||||
model_exists=model_exists,
|
||||
# services
|
||||
config=self.services.configuration,
|
||||
)
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
@ -613,7 +837,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
def invoke_internal(self, context: AppInvocationContext) -> BaseInvocationOutput:
|
||||
for field_name, field in self.model_fields.items():
|
||||
if not field.json_schema_extra or callable(field.json_schema_extra):
|
||||
# something has gone terribly awry, we should always have this and it should be a dict
|
||||
@ -635,7 +859,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
return self.invoke(context.get_restricted_context(invocation=self))
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
@ -643,7 +867,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
output = self.invoke(context.get_restricted_context(invocation=self))
|
||||
context.services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
@ -651,7 +875,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return cached_value
|
||||
else:
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
return self.invoke(context.get_restricted_context(invocation=self))
|
||||
|
||||
def get_type(self) -> str:
|
||||
return self.model_fields["type"].default
|
||||
|
@ -66,25 +66,21 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.get_model(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.get_model(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_info = context.get_model(**lora.model_dump(exclude={"weight"}))
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||
@ -93,11 +89,10 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
@ -126,7 +121,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if context.services.configuration.log_tokenization:
|
||||
if context.config.log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
@ -147,8 +142,7 @@ class CompelInvocation(BaseInvocation):
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
conditioning_name = context.save_conditioning(conditioning_data)
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
|
@ -397,7 +397,7 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = context.get_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
@ -406,21 +406,12 @@ class ImageResizeInvocation(BaseInvocation):
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=resize_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
image_name = context.save_image(image=resize_image)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_name=image_name),
|
||||
width=resize_image.width,
|
||||
height=resize_image.height,
|
||||
)
|
||||
|
||||
|
||||
|
@ -182,9 +182,8 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
orig_scheduler_info = context.get_model(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
@ -298,15 +297,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
@ -317,11 +313,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet,
|
||||
seed,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
positive_cond_data = context.get_conditioning(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
negative_cond_data = context.get_conditioning(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
@ -408,17 +404,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
input_image = context.get_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -476,22 +471,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
context.get_model(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
image_encoder_model_info = context.get_model(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
|
||||
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
|
||||
input_image = context.get_image(single_ip_adapter.image.image_name)
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@ -535,13 +528,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||
t2i_adapter_model_info = context.get_model(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
image = context.get_image(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
@ -651,11 +643,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
noise = context.get_latents(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.get_latents(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@ -681,26 +673,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
self.dispatch_progress(context, state, self.unet.unet.base_model)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.get_model(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
unet_info = context.get_model(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@ -775,9 +761,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
latents_name = context.save_latents(result_latents)
|
||||
return build_latents_output(latents_name=latents_name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -808,11 +793,10 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
latents = context.get_latents(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.get_model(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||
@ -842,7 +826,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
if self.tiled or context.config.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@ -866,21 +850,12 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata.model_dump() if self.metadata else None,
|
||||
workflow=self.workflow,
|
||||
)
|
||||
image_name = context.save_image(image, category=context.categories.GENERAL)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
image=ImageField(image_name=image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
|
||||
|
||||
|
@ -98,7 +98,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
model_type = ModelType.Main
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
if not context.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
@ -124,6 +124,5 @@ class NoiseInvocation(BaseInvocation):
|
||||
seed=self.seed,
|
||||
use_cpu=self.use_cpu,
|
||||
)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, noise)
|
||||
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
|
||||
latents_name = context.save_latents(noise)
|
||||
return build_noise_output(latents_name=latents_name, latents=noise, seed=self.seed)
|
||||
|
0
invokeai/app/invocations/shared.py
Normal file
0
invokeai/app/invocations/shared.py
Normal file
@ -4,7 +4,7 @@ from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.invocations.baseinvocation import AppInvocationContext
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
|
||||
from ..invoker import Invoker
|
||||
@ -96,18 +96,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
AppInvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -48,9 +48,12 @@ class ModelManagerServiceBase(ABC):
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
|
@ -11,6 +11,7 @@ from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
@ -86,28 +87,35 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
"""
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
self._emit_load_event(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
@ -116,15 +124,17 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
submodel,
|
||||
)
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
self._emit_load_event(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
@ -263,22 +273,25 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
if self._invoker.services.queue.is_canceled(graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -286,11 +299,11 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_info=model_info,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
|
@ -27,11 +27,9 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
|
||||
def stable_diffusion_step_callback(
|
||||
context: InvocationContext,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
if context.is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
@ -108,13 +106,7 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
context.emit_denoising_progress(
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
|
Reference in New Issue
Block a user