feat(nodes): restricts invocation context power

Creates a low-power `InvocationContext` with simplified methods and data.

See `invocation_context.py` for detailed comments.
Author: psychedelicious
Date: 2024-01-13 18:02:58 +11:00
parent 992b02aa65
commit 3d98446d5d
2 changed files with 434 additions and 13 deletions
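
In the file diff below, the step callback stops reaching into a full `InvocationContext` and instead receives a plain `InvocationContextData` value plus only the services it actually needs. As a rough sketch of that data bundle, using only the fields the diff reads (`queue_id`, `queue_item_id`, `batch_id`, `session_id`, `source_node_id`, `invocation`) -- the field types here are assumptions, and the real class in `invocation_context.py` may carry more:

    # Sketch only: field names are taken from how context_data is used in the
    # diff below; the types are assumptions, and the real class defined in
    # invocation_context.py may hold additional fields.
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class InvocationContextData:
        invocation: Any       # the invocation (node) currently being executed
        session_id: str       # graph execution state id
        queue_id: str         # id of the session queue
        queue_item_id: int    # type assumed
        batch_id: str         # id of the batch this queue item belongs to
        source_node_id: str   # id of the originating workflow node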


@@ -1,12 +1,25 @@
+from typing import Protocol
+
 import torch
 from PIL import Image
 
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
+from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 
 from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
-from ..invocations.baseinvocation import InvocationContext
+
+
+class StepCallback(Protocol):
+    def __call__(
+        self,
+        intermediate_state: PipelineIntermediateState,
+        base_model: BaseModelType,
+    ) -> None:
+        ...
 
 
 def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
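
The `StepCallback` protocol added above is structural: any callable with the same signature satisfies it, so callers can pass a plain function. A minimal illustration (not part of the commit), reusing the types imported in the file above:

    # Illustrative only -- a plain function that satisfies StepCallback purely by
    # matching its __call__ signature; a real callback would report to the UI.
    def print_step(
        intermediate_state: PipelineIntermediateState,
        base_model: BaseModelType,
    ) -> None:
        print(f"{base_model}: step {intermediate_state.step}")


    step_callback: StepCallback = print_step  # accepted by a static type checker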
@@ -25,13 +38,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
 
 
 def stable_diffusion_step_callback(
-    context: InvocationContext,
+    context_data: InvocationContextData,
     intermediate_state: PipelineIntermediateState,
-    node: dict,
-    source_node_id: str,
     base_model: BaseModelType,
-):
-    if context.services.queue.is_canceled(context.graph_execution_state_id):
+    invocation_queue: InvocationQueueABC,
+    events: EventServiceBase,
+) -> None:
+    if invocation_queue.is_canceled(context_data.session_id):
         raise CanceledException
 
     # Some schedulers report not only the noisy latents at the current timestep,
@@ -108,13 +121,13 @@ def stable_diffusion_step_callback(
 
     dataURL = image_to_dataURL(image, image_format="JPEG")
 
-    context.services.events.emit_generator_progress(
-        queue_id=context.queue_id,
-        queue_item_id=context.queue_item_id,
-        queue_batch_id=context.queue_batch_id,
-        graph_execution_state_id=context.graph_execution_state_id,
-        node=node,
-        source_node_id=source_node_id,
+    events.emit_generator_progress(
+        queue_id=context_data.queue_id,
+        queue_item_id=context_data.queue_item_id,
+        queue_batch_id=context_data.batch_id,
+        graph_execution_state_id=context_data.session_id,
+        node_id=context_data.invocation.id,
+        source_node_id=context_data.source_node_id,
         progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
         step=intermediate_state.step,
         order=intermediate_state.order,
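
With the services injected explicitly, a caller is expected to bind its context data, queue, and event service once and hand the diffusion backend a callable matching `StepCallback`. A hedged sketch of that wiring (the helper and its parameter names are illustrative, not taken from the commit):

    # Hypothetical helper -- not from the commit. It closes over the explicit
    # dependencies and returns a callable with the StepCallback signature.
    def make_step_callback(
        context_data: InvocationContextData,
        invocation_queue: InvocationQueueABC,
        events: EventServiceBase,
    ) -> StepCallback:
        def on_step(
            intermediate_state: PipelineIntermediateState,
            base_model: BaseModelType,
        ) -> None:
            stable_diffusion_step_callback(
                context_data=context_data,
                intermediate_state=intermediate_state,
                base_model=base_model,
                invocation_queue=invocation_queue,
                events=events,
            )

        return on_step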