feat(nodes): update all invocations to use new invocation context

Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Fanagling of imports to avoid circular dependencies
This commit is contained in:
psychedelicious
2024-01-13 23:23:16 +11:00
parent 97a6c6eea7
commit 7e5ba2795e
32 changed files with 716 additions and 1191 deletions

View File

@ -1,25 +1,18 @@
from typing import Protocol
from typing import TYPE_CHECKING
import torch
from PIL import Image
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
from invokeai.app.services.shared.invocation_context import InvocationContextData
from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util.util import image_to_dataURL
class StepCallback(Protocol):
def __call__(
self,
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
) -> None:
...
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
from invokeai.app.services.shared.invocation_context import InvocationContextData
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
def stable_diffusion_step_callback(
context_data: InvocationContextData,
context_data: "InvocationContextData",
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
invocation_queue: InvocationQueueABC,
events: EventServiceBase,
invocation_queue: "InvocationQueueABC",
events: "EventServiceBase",
) -> None:
if invocation_queue.is_canceled(context_data.session_id):
raise CanceledException