diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py
index 395d5e9870..c4aed1fac5 100644
--- a/invokeai/app/invocations/baseinvocation.py
+++ b/invokeai/app/invocations/baseinvocation.py
@@ -16,10 +16,16 @@ from pydantic import BaseModel, ConfigDict, Field, create_model
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefined
 
-from invokeai.app.invocations.fields import FieldKind, Input
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FieldKind,
+    Input,
+    InputFieldJSONSchemaExtra,
+    MetadataField,
+    logger,
+)
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
-from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.metaenum import MetaEnum
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.util.logging import InvokeAILogger
@@ -219,7 +225,7 @@ class BaseInvocation(ABC, BaseModel):
         """Invoke with provided context and return outputs."""
         pass
 
-    def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
+    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
         """
         Internal invoke method, calls `invoke()` after some prep.
         Handles optional fields that are required to call `invoke()` and invocation cache.
@@ -244,23 +250,23 @@ class BaseInvocation(ABC, BaseModel):
                     raise MissingInputException(self.model_fields["type"].default, field_name)
 
         # skip node cache codepath if it's disabled
-        if context.services.configuration.node_cache_size == 0:
+        if services.configuration.node_cache_size == 0:
             return self.invoke(context)
 
         output: BaseInvocationOutput
         if self.use_cache:
-            key = context.services.invocation_cache.create_key(self)
-            cached_value = context.services.invocation_cache.get(key)
+            key = services.invocation_cache.create_key(self)
+            cached_value = services.invocation_cache.get(key)
             if cached_value is None:
-                context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
+                services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
                 output = self.invoke(context)
-                context.services.invocation_cache.save(key, output)
+                services.invocation_cache.save(key, output)
                 return output
             else:
-                context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
+                services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
                 return cached_value
         else:
-            context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
+            services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
             return self.invoke(context)
 
     id: str = Field(
@@ -513,3 +519,29 @@ def invocation_output(
         return cls
 
     return wrapper
+
+
+class WithMetadata(BaseModel):
+    """
+    Inherit from this class if your node needs a metadata input field.
+    """
+
+    metadata: Optional[MetadataField] = Field(
+        default=None,
+        description=FieldDescriptions.metadata,
+        json_schema_extra=InputFieldJSONSchemaExtra(
+            field_kind=FieldKind.Internal,
+            input=Input.Connection,
+            orig_required=False,
+        ).model_dump(exclude_none=True),
+    )
+
+
+class WithWorkflow:
+    workflow = None
+
+    def __init_subclass__(cls) -> None:
+        logger.warn(
+            f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
+        )
+        super().__init_subclass__()
diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py
index d35a9d79c7..f5709b4ba3 100644
--- a/invokeai/app/invocations/collections.py
+++ b/invokeai/app/invocations/collections.py
@@ -7,7 +7,7 @@ from pydantic import ValidationInfo, field_validator
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField
 
 
@@ -27,7 +27,7 @@ class RangeInvocation(BaseInvocation):
             raise ValueError("stop must be greater than start")
         return v
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
 
 
@@ -45,7 +45,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     size: int = InputField(default=1, gt=0, description="The number of values")
     step: int = InputField(default=1, description="The step of the range")
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(
             collection=list(range(self.start, self.start + (self.step * self.size), self.step))
         )
@@ -72,6 +72,6 @@ class RandomRangeInvocation(BaseInvocation):
         description="The seed for the RNG (omit for random)",
     )
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         rng = np.random.default_rng(self.seed)
         return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index b386aef2cb..b4496031bc 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,12 +1,18 @@
-from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.app.invocations.fields import (
+    ConditioningFieldData,
+    FieldDescriptions,
+    Input,
+    InputField,
+    OutputField,
+    UIComponent,
+)
+from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ExtraConditioningInfo,
@@ -20,16 +26,14 @@ from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
 from .model import ClipField
 
+if TYPE_CHECKING:
+    from invokeai.app.services.shared.invocation_context import InvocationContext
 
-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
     # unconditioned: Optional[torch.Tensor]
@@ -44,7 +48,7 @@ class ConditioningFieldData:
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class
CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" @@ -61,26 +65,18 @@ class CompelInvocation(BaseInvocation): ) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - context=context, - ) + def invoke(self, context) -> ConditioningOutput: + tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump()) + text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump()) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(self.prompt): @@ -89,11 +85,10 @@ class CompelInvocation(BaseInvocation): ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=self.clip.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -124,7 +119,7 @@ class CompelInvocation(BaseInvocation): conjunction = Compel.parse_prompt_string(self.prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) @@ -145,34 +140,23 @@ class CompelInvocation(BaseInvocation): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) class SDXLPromptInvocationBase: def run_clip_compel( self, - context: InvocationContext, + context: "InvocationContext", clip_field: ClipField, prompt: str, get_pooled: bool, lora_prefix: str, zero_on_empty: bool, ): - tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.model_dump(), - context=context, - ) - text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.model_dump(), - context=context, - ) + tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump()) + text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump()) # return zero on empty if prompt == "" and zero_on_empty: @@ -196,14 +180,12 @@ class SDXLPromptInvocationBase: def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), context=context - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - # loras = 
[(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in extract_ti_triggers_from_prompt(prompt): @@ -212,11 +194,10 @@ class SDXLPromptInvocationBase: ti_list.append( ( name, - context.services.model_manager.get_model( + context.models.load( model_name=name, base_model=clip_field.text_encoder.base_model, model_type=ModelType.TextualInversion, - context=context, ).context.model, ) ) @@ -249,7 +230,7 @@ class SDXLPromptInvocationBase: conjunction = Compel.parse_prompt_string(prompt) - if context.services.configuration.log_tokenization: + if context.config.get().log_tokenization: # TODO: better logging for and syntax log_tokenization_for_conjunction(conjunction, tokenizer) @@ -282,7 +263,7 @@ class SDXLPromptInvocationBase: title="SDXL Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -307,7 +288,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: c1, c1_pooled, ec1 = self.run_clip_compel( context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True ) @@ -364,14 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) @invocation( @@ -379,7 +355,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): title="SDXL Refiner Prompt", tags=["sdxl", "compel", "prompt"], category="conditioning", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -397,7 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) @@ -417,14 +393,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase ] ) - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - context.services.latents.save(conditioning_name, conditioning_data) + conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) + return ConditioningOutput.build(conditioning_name) 
@invocation_output("clip_skip_output") @@ -447,7 +418,7 @@ class ClipSkipInvocation(BaseInvocation): clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) - def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: + def invoke(self, context) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers return ClipSkipInvocationOutput( clip=self.clip, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 9b652b8eee..3797722c93 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -25,18 +25,17 @@ from controlnet_aux.util import HWC3, ade_palette from PIL import Image from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.util import validate_begin_end_step, validate_weights -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.image_util.depth_anything import DepthAnythingDetector from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector +from invokeai.backend.model_management.models.base import BaseModelType -from ...backend.model_management import BaseModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -121,7 +120,7 @@ class ControlNetInvocation(BaseInvocation): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> ControlOutput: + def invoke(self, context) -> ControlOutput: return ControlOutput( control=ControlField( image=self.image, @@ -145,23 +144,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata): # superclass just passes through image without processing return image - def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + raw_image = context.images.get_pil(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? 
processed_image = self.run_processor(raw_image) # currently can't see processed image in node UI without a showImage node, # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery - image_dto = context.services.images.create( - image=processed_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.CONTROL, - session_id=context.graph_execution_state_id, - node_id=self.id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=processed_image) """Builds an ImageOutput and its ImageField""" processed_image_field = ImageField(image_name=image_dto.image_name) @@ -180,7 +170,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata): title="Canny Processor", tags=["controlnet", "canny"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" @@ -203,7 +193,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation): title="HED (softedge) Processor", tags=["controlnet", "hed", "softedge"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" @@ -232,7 +222,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation): title="Lineart Processor", tags=["controlnet", "lineart"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" @@ -254,7 +244,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation): title="Lineart Anime Processor", tags=["controlnet", "lineart", "anime"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" @@ -277,7 +267,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): title="Midas Depth Processor", tags=["controlnet", "midas"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" @@ -304,7 +294,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): title="Normal BAE Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" @@ -321,7 +311,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): @invocation( - "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0" + "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1" ) class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" @@ -344,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): @invocation( - "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0" + "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1" ) class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" @@ -371,7 +361,7 @@ class 
PidiImageProcessorInvocation(ImageProcessorInvocation): title="Content Shuffle Processor", tags=["controlnet", "contentshuffle"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle processing to image""" @@ -401,7 +391,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): title="Zoe (Depth) Processor", tags=["controlnet", "zoe", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" @@ -417,7 +407,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): title="Mediapipe Face Processor", tags=["controlnet", "mediapipe", "face"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" @@ -440,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): title="Leres (Depth) Processor", tags=["controlnet", "leres", "depth"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" @@ -469,7 +459,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): title="Tile Resample Processor", tags=["controlnet", "tile"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class TileResamplerProcessorInvocation(ImageProcessorInvocation): """Tile resampler processor""" @@ -509,7 +499,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): title="Segment Anything Processor", tags=["controlnet", "segmentanything"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" @@ -551,7 +541,7 @@ class SamDetectorReproducibleColors(SamDetector): title="Color Map Processor", tags=["controlnet"], category="controlnet", - version="1.2.0", + version="1.2.1", ) class ColorMapImageProcessorInvocation(ImageProcessorInvocation): """Generates a color map from the provided image""" diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5865338e19..375b18f9c5 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,23 +5,23 @@ import cv2 as cv import numpy from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata -@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0") +@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1") class CvInpaintInvocation(BaseInvocation, WithMetadata): """Simple inpaint using opencv.""" image: ImageField = InputField(description="The image to inpaint") mask: ImageField = InputField(description="The mask to use when inpainting") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = 
context.services.images.get_pil_image(self.image.image_name) - mask = context.services.images.get_pil_image(self.mask.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + mask = context.images.get_pil(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -35,18 +35,6 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata): # TODO: consider making a utility function image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_dto = context.services.images.create( - image=image_inpainted, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_inpainted) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 13f1066ec3..2c92e28cfe 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -1,7 +1,7 @@ import math import re from pathlib import Path -from typing import Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict import cv2 import numpy as np @@ -13,13 +13,16 @@ from pydantic import field_validator import invokeai.assets.fonts as font_assets from invokeai.app.invocations.baseinvocation import ( BaseInvocation, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory + +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext @invocation_output("face_mask_output") @@ -174,7 +177,7 @@ def prepare_faces_list( def generate_face_box_mask( - context: InvocationContext, + context: "InvocationContext", minimum_confidence: float, x_offset: float, y_offset: float, @@ -273,7 +276,7 @@ def generate_face_box_mask( def extract_face( - context: InvocationContext, + context: "InvocationContext", image: ImageType, face: FaceResultData, padding: int, @@ -304,37 +307,37 @@ def extract_face( # Adjust the crop boundaries to stay within the original image's dimensions if x_min < 0: - context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.") + context.logger.warning("FaceTools --> -X-axis padding reached image edge.") x_max -= x_min x_min = 0 elif x_max > mask.width: - context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.") + context.logger.warning("FaceTools --> +X-axis padding reached image edge.") x_min -= x_max - mask.width x_max = mask.width if y_min < 0: - context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> +Y-axis padding reached image edge.") y_max -= y_min y_min = 0 elif y_max > mask.height: - context.services.logger.warning("FaceTools --> 
-Y-axis padding reached image edge.") + context.logger.warning("FaceTools --> -Y-axis padding reached image edge.") y_min -= y_max - mask.height y_max = mask.height # Ensure the crop is square and adjust the boundaries if needed if x_max - x_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.") diff = crop_size - (x_max - x_min) x_min -= diff // 2 x_max += diff - diff // 2 if y_max - y_min != crop_size: - context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") + context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.") diff = crop_size - (y_max - y_min) y_min -= diff // 2 y_max += diff - diff // 2 - context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") + context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}") # Crop the output image to the specified size with the center of the face mesh as the center. mask = mask.crop((x_min, y_min, x_max, y_max)) @@ -354,7 +357,7 @@ def extract_face( def get_faces_list( - context: InvocationContext, + context: "InvocationContext", image: ImageType, should_chunk: bool, minimum_confidence: float, @@ -366,7 +369,7 @@ def get_faces_list( # Generate the face box mask and get the center of the face. if not should_chunk: - context.services.logger.info("FaceTools --> Attempting full image face detection.") + context.logger.info("FaceTools --> Attempting full image face detection.") result = generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -378,7 +381,7 @@ def get_faces_list( draw_mesh=draw_mesh, ) if should_chunk or len(result) == 0: - context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") + context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).") width, height = image.size image_chunks = [] x_offsets = [] @@ -397,7 +400,7 @@ def get_faces_list( x_offsets.append(x) y_offsets.append(0) fx += increment - context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}") + context.logger.info(f"FaceTools --> Chunk starting at x = {x}") elif height > width: # Portrait - slice the image vertically fy = 0.0 @@ -409,10 +412,10 @@ def get_faces_list( x_offsets.append(0) y_offsets.append(y) fy += increment - context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}") + context.logger.info(f"FaceTools --> Chunk starting at y = {y}") for idx in range(len(image_chunks)): - context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") + context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}") result = result + generate_face_box_mask( context=context, minimum_confidence=minimum_confidence, @@ -426,7 +429,7 @@ def get_faces_list( if len(result) == 0: # Give up - context.services.logger.warning( + context.logger.warning( "FaceTools --> No face detected in chunked input image. Passing through original image." 
) @@ -435,7 +438,7 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0") +@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1") class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -456,7 +459,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.", ) - def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]: + def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]: all_faces = get_faces_list( context=context, image=image, @@ -468,11 +471,11 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): ) if len(all_faces) == 0: - context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.") + context.logger.warning("FaceOff --> No faces detected. Passing through original image.") return None if self.face_id > len(all_faces) - 1: - context.services.logger.warning( + context.logger.warning( f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image." ) return None @@ -483,8 +486,8 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): return face_data - def invoke(self, context: InvocationContext) -> FaceOffOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceOffOutput: + image = context.images.get_pil(self.image.image_name) result = self.faceoff(context=context, image=image) if result is None: @@ -498,24 +501,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): x = result["x_min"] y = result["y_min"] - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - mask_dto = context.services.images.create( - image=result_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK) output = FaceOffOutput( image=ImageField(image_name=image_dto.image_name), @@ -529,7 +517,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0") +@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1") class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -556,7 +544,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): raise ValueError('Face IDs must be a comma-separated list of integers (e.g. 
"1,2,3")') return v - def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult: + def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult: all_faces = get_faces_list( context=context, image=image, @@ -578,7 +566,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): if len(intersected_face_ids) == 0: id_range_str = ",".join([str(id) for id in id_range]) - context.services.logger.warning( + context.logger.warning( f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image." ) return FaceMaskResult( @@ -613,28 +601,13 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): mask=mask_pil, ) - def invoke(self, context: InvocationContext) -> FaceMaskOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> FaceMaskOutput: + image = context.images.get_pil(self.image.image_name) result = self.facemask(context=context, image=image) - image_dto = context.services.images.create( - image=result["image"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result["image"]) - mask_dto = context.services.images.create( - image=result["mask"], - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK) output = FaceMaskOutput( image=ImageField(image_name=image_dto.image_name), @@ -647,7 +620,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0" + "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1" ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" @@ -661,7 +634,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): description="Whether to bypass full image face detection and default to image chunking. 
Chunking will occur if no faces are found in the full image.", ) - def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType: + def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType: image = image.copy() all_faces = get_faces_list( @@ -702,22 +675,10 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata): return image - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) result_image = self.faceidentifier(context=context, image=image) - image_dto = context.services.images.create( - image=result_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - workflow=context.workflow, - ) + image_dto = context.images.save(image=result_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 0cce8e3c6b..566babbb6b 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,11 +1,13 @@ +from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -255,6 +257,10 @@ class InputFieldJSONSchemaExtra(BaseModel): class WithMetadata(BaseModel): + """ + Inherit from this class if your node needs a metadata input field. 
+ """ + metadata: Optional[MetadataField] = Field( default=None, description=FieldDescriptions.metadata, @@ -498,4 +504,53 @@ def OutputField( field_kind=FieldKind.Output, ).model_dump(exclude_none=True), ) + + +class ImageField(BaseModel): + """An image primitive field""" + + image_name: str = Field(description="The name of the image") + + +class BoardField(BaseModel): + """A board primitive field""" + + board_id: str = Field(description="The id of the board") + + +class DenoiseMaskField(BaseModel): + """An inpaint mask field""" + + mask_name: str = Field(description="The name of the mask image") + masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") + + +class LatentsField(BaseModel): + """A latents tensor primitive field""" + + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") + + +class ColorField(BaseModel): + """A color primitive field""" + + r: int = Field(ge=0, le=255, description="The red component") + g: int = Field(ge=0, le=255, description="The green component") + b: int = Field(ge=0, le=255, description="The blue component") + a: int = Field(ge=0, le=255, description="The alpha component") + + def tuple(self) -> Tuple[int, int, int, int]: + return (self.r, self.g, self.b, self.a) + + +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + +class ConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") # endregion diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 16d0f33dda..10ebd97ace 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,30 +7,36 @@ import cv2 import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, WithMetadata -from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin +from invokeai.app.invocations.baseinvocation import WithMetadata +from invokeai.app.invocations.fields import ( + BoardField, + ColorField, + FieldDescriptions, + ImageField, + Input, + InputField, +) +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker from .baseinvocation import ( BaseInvocation, Classification, - InvocationContext, invocation, ) -@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0") +@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1") class ShowImageInvocation(BaseInvocation): """Displays a provided image using the OS image viewer, and passes it forward in the pipeline.""" image: ImageField = InputField(description="The image to show") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - if image: - image.show() + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + image.show() # TODO: how to handle 
failure? @@ -46,7 +52,7 @@ class ShowImageInvocation(BaseInvocation): title="Blank Image", tags=["image"], category="image", - version="1.2.0", + version="1.2.1", ) class BlankImageInvocation(BaseInvocation, WithMetadata): """Creates a blank image and forwards it to the pipeline""" @@ -56,25 +62,12 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image") color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple()) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -82,7 +75,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata): title="Crop Image", tags=["image", "crop"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageCropInvocation(BaseInvocation, WithMetadata): """Crops an image to a specified box. The box can be outside of the image.""" @@ -93,28 +86,15 @@ class ImageCropInvocation(BaseInvocation, WithMetadata): width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -145,8 +125,8 @@ class CenterPadCropInvocation(BaseInvocation): description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) # Calculate and create new image dimensions new_width = image.width + self.right + self.left @@ -156,20 +136,9 @@ class CenterPadCropInvocation(BaseInvocation): # Paste new image onto input image_crop.paste(image, (self.left, self.top)) - image_dto = context.services.images.create( - image=image_crop, - image_origin=ResourceOrigin.INTERNAL, - 
image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - ) + image_dto = context.images.save(image=image_crop) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -177,7 +146,7 @@ class CenterPadCropInvocation(BaseInvocation): title="Paste Image", tags=["image", "paste"], category="image", - version="1.2.0", + version="1.2.1", ) class ImagePasteInvocation(BaseInvocation, WithMetadata): """Pastes an image into another image.""" @@ -192,12 +161,12 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): y: int = InputField(default=0, description="The top y coordinate at which to paste the image") crop: bool = InputField(default=False, description="Crop to base image dimensions") - def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image(self.base_image.image_name) - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + base_image = context.images.get_pil(self.base_image.image_name) + image = context.images.get_pil(self.image.image_name) mask = None if self.mask is not None: - mask = context.services.images.get_pil_image(self.mask.image_name) + mask = context.images.get_pil(self.mask.image_name) mask = ImageOps.invert(mask.convert("L")) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -214,22 +183,9 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): base_w, base_h = base_image.size new_image = new_image.crop((abs(min_x), abs(min_y), abs(min_x) + base_w, abs(min_y) + base_h)) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -237,7 +193,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata): title="Mask from Alpha", tags=["image", "mask"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): """Extracts the alpha channel of an image as a mask.""" @@ -245,29 +201,16 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to create the mask from") invert: bool = InputField(default=False, description="Whether or not to invert the mask") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_mask = image.split()[-1] if self.invert: image_mask = ImageOps.invert(image_mask) - image_dto = context.services.images.create( - image=image_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = 
context.images.save(image=image_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -275,7 +218,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata): title="Multiply Images", tags=["image", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageMultiplyInvocation(BaseInvocation, WithMetadata): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" @@ -283,28 +226,15 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata): image1: ImageField = InputField(description="The first image to multiply") image2: ImageField = InputField(description="The second image to multiply") - def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image(self.image1.image_name) - image2 = context.services.images.get_pil_image(self.image2.image_name) + def invoke(self, context) -> ImageOutput: + image1 = context.images.get_pil(self.image1.image_name) + image2 = context.images.get_pil(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) - image_dto = context.services.images.create( - image=multiply_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=multiply_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_CHANNELS = Literal["A", "R", "G", "B"] @@ -315,7 +245,7 @@ IMAGE_CHANNELS = Literal["A", "R", "G", "B"] title="Extract Image Channel", tags=["image", "channel"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelInvocation(BaseInvocation, WithMetadata): """Gets a channel from an image.""" @@ -323,27 +253,14 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to get the channel from") channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) channel_image = image.getchannel(self.channel) - image_dto = context.services.images.create( - image=channel_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=channel_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] @@ -354,7 +271,7 @@ IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F title="Convert Image Mode", tags=["image", "convert"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageConvertInvocation(BaseInvocation, WithMetadata): """Converts an image to a 
different mode.""" @@ -362,27 +279,14 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to convert") mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) converted_image = image.convert(self.mode) - image_dto = context.services.images.create( - image=converted_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=converted_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -390,7 +294,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata): title="Blur Image", tags=["image", "blur"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageBlurInvocation(BaseInvocation, WithMetadata): """Blurs an image""" @@ -400,30 +304,17 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): # Metadata blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) - image_dto = context.services.images.create( - image=blur_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=blur_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -431,7 +322,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata): title="Unsharp Mask", tags=["image", "unsharp_mask"], category="image", - version="1.2.0", + version="1.2.1", classification=Classification.Beta, ) class UnsharpMaskInvocation(BaseInvocation, WithMetadata): @@ -447,8 +338,8 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata): def array_from_pil(self, img): return numpy.array(img) / 255 - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) mode = image.mode alpha_channel = image.getchannel("A") if mode == "RGBA" else None @@ -466,16 +357,7 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata): if alpha_channel is not None: image.putalpha(alpha_channel) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, 
- is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) return ImageOutput( image=ImageField(image_name=image_dto.image_name), @@ -509,7 +391,7 @@ PIL_RESAMPLING_MAP = { title="Resize Image", tags=["image", "resize"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageResizeInvocation(BaseInvocation, WithMetadata): """Resizes an image to specific dimensions""" @@ -519,8 +401,8 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): height: int = InputField(default=512, gt=0, description="The height to resize to (px)") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -529,22 +411,9 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -552,7 +421,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata): title="Scale Image", tags=["image", "scale"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageScaleInvocation(BaseInvocation, WithMetadata): """Scales an image by a factor""" @@ -565,8 +434,8 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): ) resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -577,22 +446,9 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): resample=resample_mode, ) - image_dto = context.services.images.create( - image=resize_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=resize_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -600,7 +456,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata): title="Lerp Image", tags=["image", "lerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageLerpInvocation(BaseInvocation, WithMetadata): """Linear interpolation of all pixels of an image""" @@ -609,30 +465,17 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, 
description="The minimum output value") max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.min lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=lerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=lerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -640,7 +483,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata): title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): """Inverse linear interpolation of all pixels of an image""" @@ -649,30 +492,17 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 # type: ignore [assignment] ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_dto = context.services.images.create( - image=ilerp_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=ilerp_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -680,17 +510,17 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata): title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): """Add blur to NSFW-flagged images""" image: ImageField = InputField(description="The image to check") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - logger = context.services.logger + logger = context.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been detected. 
Image will be blurred.") @@ -699,22 +529,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): blurry_image.paste(caution, (0, 0), caution) image = blurry_image - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) def _get_caution_img(self) -> Image.Image: import invokeai.app.assets.images as image_assets @@ -728,7 +545,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata): title="Add Invisible Watermark", tags=["image", "watermark"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageWatermarkInvocation(BaseInvocation, WithMetadata): """Add an invisible watermark to an image""" @@ -736,25 +553,12 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to check") text: str = InputField(default="InvokeAI", description="Watermark text") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) new_image = InvisibleWatermark.add_watermark(image, self.text) - image_dto = context.services.images.create( - image=new_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -762,7 +566,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata): title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata): """Applies an edge mask to an image""" @@ -775,8 +579,8 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): description="Second threshold for the hysteresis procedure in Canny edge detection" ) - def invoke(self, context: InvocationContext) -> ImageOutput: - mask = context.services.images.get_pil_image(self.image.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask = context.images.get_pil(self.image.image_name).convert("L") npimg = numpy.asarray(mask, dtype=numpy.uint8) npgradient = numpy.uint8(255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))) @@ -791,22 +595,9 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): new_mask = ImageOps.invert(new_mask) - image_dto = context.services.images.create( - image=new_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.MASK, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=new_mask, image_category=ImageCategory.MASK) - return ImageOutput( - 
image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -814,7 +605,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata): title="Combine Masks", tags=["image", "mask", "multiply"], category="image", - version="1.2.0", + version="1.2.1", ) class MaskCombineInvocation(BaseInvocation, WithMetadata): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" @@ -822,28 +613,15 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): mask1: ImageField = InputField(description="The first mask to combine") mask2: ImageField = InputField(description="The second image to combine") - def invoke(self, context: InvocationContext) -> ImageOutput: - mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") - mask2 = context.services.images.get_pil_image(self.mask2.image_name).convert("L") + def invoke(self, context) -> ImageOutput: + mask1 = context.images.get_pil(self.mask1.image_name).convert("L") + mask2 = context.images.get_pil(self.mask2.image_name).convert("L") combined_mask = ImageChops.multiply(mask1, mask2) - image_dto = context.services.images.create( - image=combined_mask, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=combined_mask, image_category=ImageCategory.MASK) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -851,7 +629,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata): title="Color Correct", tags=["image", "color"], category="image", - version="1.2.0", + version="1.2.1", ) class ColorCorrectInvocation(BaseInvocation, WithMetadata): """ @@ -864,14 +642,14 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") mask_blur_radius: float = InputField(default=8, description="Mask blur radius") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: pil_init_mask = None if self.mask is not None: - pil_init_mask = context.services.images.get_pil_image(self.mask.image_name).convert("L") + pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L") - init_image = context.services.images.get_pil_image(self.reference.image_name) + init_image = context.images.get_pil(self.reference.image_name) - result = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + result = context.images.get_pil(self.image.image_name).convert("RGBA") # if init_image is None or init_mask is None: # return result @@ -945,22 +723,9 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) - image_dto = context.services.images.create( - image=matched_result, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = 
context.images.save(image=matched_result) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -968,7 +733,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata): title="Adjust Image Hue", tags=["image", "hue"], category="image", - version="1.2.0", + version="1.2.1", ) class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): """Adjusts the Hue of an image.""" @@ -976,8 +741,8 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description="The image to adjust") hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # Convert image to HSV color space hsv_image = numpy.array(pil_image.convert("HSV")) @@ -991,24 +756,9 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata): # Convert back to PIL format and to original color mode pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) COLOR_CHANNELS = Literal[ @@ -1072,7 +822,7 @@ CHANNEL_FORMATS = { "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): """Add or subtract a value from a specific color channel of an image.""" @@ -1081,8 +831,8 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): channel: COLOR_CHANNELS = InputField(description="Which channel to adjust") offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1101,24 +851,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1143,7 +878,7 @@ class 
ImageChannelOffsetInvocation(BaseInvocation, WithMetadata): "value", ], category="image", - version="1.2.0", + version="1.2.1", ) class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): """Scale a specific color channel of an image.""" @@ -1153,8 +888,8 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.") invert_channel: bool = InputField(default=False, description="Invert the channel after scaling") - def invoke(self, context: InvocationContext) -> ImageOutput: - pil_image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + pil_image = context.images.get_pil(self.image.image_name) # extract the channel and mode from the input and reference tuple mode = CHANNEL_FORMATS[self.channel][0] @@ -1177,24 +912,9 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): # Convert back to RGBA format and output pil_image = Image.fromarray(converted_image.astype(numpy.uint8), mode=mode).convert("RGBA") - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - is_intermediate=self.is_intermediate, - session_id=context.graph_execution_state_id, - workflow=context.workflow, - metadata=self.metadata, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - ), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1202,7 +922,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata): title="Save Image", tags=["primitives", "image"], category="primitives", - version="1.2.0", + version="1.2.1", use_cache=False, ) class SaveImageInvocation(BaseInvocation, WithMetadata): @@ -1211,26 +931,12 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - board_id=self.board.board_id if self.board else None, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( @@ -1238,7 +944,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata): title="Linear UI Image Output", tags=["primitives", "image"], category="primitives", - version="1.0.1", + version="1.0.2", use_cache=False, ) class LinearUIOutputInvocation(BaseInvocation, WithMetadata): @@ -1247,19 +953,13 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata): image: ImageField = InputField(description=FieldDescriptions.image) board: Optional[BoardField] = InputField(default=None, 
description=FieldDescriptions.board, input=Input.Direct) - def invoke(self, context: InvocationContext) -> ImageOutput: - image_dto = context.services.images.get_dto(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image_dto = context.images.get_dto(self.image.image_name) - if self.board: - context.services.board_images.add_image_to_board(self.board.board_id, self.image.image_name) - - if image_dto.is_intermediate != self.is_intermediate: - context.services.images.update( - self.image.image_name, changes=ImageRecordChanges(is_intermediate=self.is_intermediate) - ) - - return ImageOutput( - image=ImageField(image_name=self.image.image_name), - width=image_dto.width, - height=image_dto.height, + image_dto = context.images.update( + image_name=self.image.image_name, + board_id=self.board.board_id if self.board else None, + is_intermediate=self.is_intermediate, ) + + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index d4d3d5bea4..be51c8312f 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -6,15 +6,15 @@ from typing import Literal, Optional, get_args import numpy as np from PIL import Image, ImageOps -from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ColorField, ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch -from .baseinvocation import BaseInvocation, InvocationContext, invocation -from .fields import InputField, WithMetadata +from .baseinvocation import BaseInvocation, WithMetadata, invocation +from .fields import InputField from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES @@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si -@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class InfillColorInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with a solid color""" @@ -129,33 +129,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata): description="The color to use to infill", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - 
height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") class InfillTileInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image with tiles of the image""" @@ -168,32 +155,19 @@ class InfillTileInvocation(BaseInvocation, WithMetadata): description="The seed to use for tile generation (omit for random)", ) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) @invocation( - "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0" + "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1" ) class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the PatchMatch algorithm""" @@ -202,8 +176,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill") resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA") + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name).convert("RGBA") resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -228,77 +202,38 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata): infilled.paste(image, (0, 0), mask=image.split()[-1]) # image.paste(infilled, (0, 0), mask=image.split()[-1]) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class LaMaInfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using the LaMa model""" image: 
ImageField = InputField(description="The image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_lama(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) -@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0") +@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1") class CV2InfillInvocation(BaseInvocation, WithMetadata): """Infills transparent areas of an image using OpenCV Inpainting""" image: ImageField = InputField(description="The image to infill") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) infilled = infill_cv2(image.copy()) - image_dto = context.services.images.create( - image=infilled, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=infilled) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index c01e0ed0fb..b836be04b5 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -62,7 +61,7 @@ class IPAdapterOutput(BaseInvocationOutput): ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -93,9 +92,9 @@ class IPAdapterInvocation(BaseInvocation): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> IPAdapterOutput: + def invoke(self, context) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
- ip_adapter_info = context.services.model_manager.model_info( + ip_adapter_info = context.models.get_info( self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter ) # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model @@ -104,7 +103,7 @@ class IPAdapterInvocation(BaseInvocation): # is currently messy due to differences between how the model info is generated when installing a model from # disk vs. downloading the model. image_encoder_model_id = get_ip_adapter_image_encoder_model_id( - os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"]) + os.path.join(context.config.get().models_path, ip_adapter_info["path"]) ) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model = CLIPVisionModelField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 909c307481..0127a6521e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import math from contextlib import ExitStack from functools import singledispatchmethod -from typing import List, Literal, Optional, Union +from typing import TYPE_CHECKING, List, Literal, Optional, Union import einops import numpy as np @@ -23,21 +23,26 @@ from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import field_validator from torchvision.transforms.functional import resize as tv_resize -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata +from invokeai.app.invocations.fields import ( + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIType, + WithMetadata, +) from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.primitives import ( - DenoiseMaskField, DenoiseMaskOutput, - ImageField, ImageOutput, - LatentsField, LatentsOutput, - build_latents_output, ) from invokeai.app.invocations.t2i_adapter import T2IAdapterField -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.util.controlnet_utils import prepare_control_image -from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo @@ -59,14 +64,15 @@ from ...backend.util.devices import choose_precision, choose_torch_device from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) -from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField +if TYPE_CHECKING: + from invokeai.app.services.shared.invocation_context import InvocationContext + if choose_torch_device() == torch.device("mps"): from torch import mps @@ -102,7 +108,7 @@ class SchedulerInvocation(BaseInvocation): ui_type=UIType.Scheduler, ) - def invoke(self, context: InvocationContext) -> SchedulerOutput: + def invoke(self, context) -> SchedulerOutput: return SchedulerOutput(scheduler=self.scheduler) @@ -111,7 +117,7 @@ class SchedulerInvocation(BaseInvocation): title="Create Denoise Mask", tags=["mask", "denoise"], 
category="latents", - version="1.0.0", + version="1.0.1", ) class CreateDenoiseMaskInvocation(BaseInvocation): """Creates mask for denoising model run.""" @@ -137,9 +143,9 @@ class CreateDenoiseMaskInvocation(BaseInvocation): return mask_tensor @torch.no_grad() - def invoke(self, context: InvocationContext) -> DenoiseMaskOutput: + def invoke(self, context) -> DenoiseMaskOutput: if self.image is not None: - image = context.services.images.get_pil_image(self.image.image_name) + image = context.images.get_pil(self.image.image_name) image = image_resized_to_grid_as_tensor(image.convert("RGB")) if image.dim() == 3: image = image.unsqueeze(0) @@ -147,47 +153,37 @@ class CreateDenoiseMaskInvocation(BaseInvocation): image = None mask = self.prep_mask_tensor( - context.services.images.get_pil_image(self.mask.image_name), + context.images.get_pil(self.mask.image_name), ) if image is not None: - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0) # TODO: masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone()) - masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents" - context.services.latents.save(masked_latents_name, masked_latents) + masked_latents_name = context.latents.save(tensor=masked_latents) else: masked_latents_name = None - mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" - context.services.latents.save(mask_name, mask) + mask_name = context.latents.save(tensor=mask) - return DenoiseMaskOutput( - denoise_mask=DenoiseMaskField( - mask_name=mask_name, - masked_latents_name=masked_latents_name, - ), + return DenoiseMaskOutput.build( + mask_name=mask_name, + masked_latents_name=masked_latents_name, ) def get_scheduler( - context: InvocationContext, + context: "InvocationContext", scheduler_info: ModelInfo, scheduler_name: str, seed: int, ) -> Scheduler: scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) - orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.model_dump(), - context=context, - ) + orig_scheduler_info = context.models.load(**scheduler_info.model_dump()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -216,7 +212,7 @@ def get_scheduler( title="Denoise Latents", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], category="latents", - version="1.5.1", + version="1.5.2", ) class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" @@ -302,34 +298,18 @@ class DenoiseLatentsInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - # TODO: pass this an emitter method or something? or a session for dispatching? 
- def dispatch_progress( - self, - context: InvocationContext, - source_node_id: str, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - base_model=base_model, - ) - def get_conditioning_data( self, - context: InvocationContext, + context: "InvocationContext", scheduler, unet, seed, ) -> ConditioningData: - positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) extra_conditioning_info = c.extra_conditioning - negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) conditioning_data = ConditioningData( @@ -389,7 +369,7 @@ class DenoiseLatentsInvocation(BaseInvocation): def prep_control_data( self, - context: InvocationContext, + context: "InvocationContext", control_input: Union[ControlField, List[ControlField]], latents_shape: List[int], exit_stack: ExitStack, @@ -417,17 +397,16 @@ class DenoiseLatentsInvocation(BaseInvocation): controlnet_data = [] for control_info in control_list: control_model = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=control_info.control_model.model_name, model_type=ModelType.ControlNet, base_model=control_info.control_model.base_model, - context=context, ) ) # control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.images.get_pil(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -463,7 +442,7 @@ class DenoiseLatentsInvocation(BaseInvocation): def prep_ip_adapter_data( self, - context: InvocationContext, + context: "InvocationContext", ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], conditioning_data: ConditioningData, exit_stack: ExitStack, @@ -485,19 +464,17 @@ class DenoiseLatentsInvocation(BaseInvocation): conditioning_data.ip_adapter_conditioning = [] for single_ip_adapter in ip_adapter: ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( - context.services.model_manager.get_model( + context.models.load( model_name=single_ip_adapter.ip_adapter_model.model_name, model_type=ModelType.IPAdapter, base_model=single_ip_adapter.ip_adapter_model.base_model, - context=context, ) ) - image_encoder_model_info = context.services.model_manager.get_model( + image_encoder_model_info = context.models.load( model_name=single_ip_adapter.image_encoder_model.model_name, model_type=ModelType.CLIPVision, base_model=single_ip_adapter.image_encoder_model.base_model, - context=context, ) # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here. 
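The substitutions above repeat mechanically across every node in this changeset: images are read with `context.images.get_pil(...)`, models are loaded with `context.models.load(...)` and no longer need `context=context`, results are stored with `context.images.save(...)`, and outputs are built with `ImageOutput.build(...)`. A hedged sketch of a node written directly against the ported API follows; the node id, class name, and grayscale step are hypothetical, while the imports, the context calls, and the `ImageOutput.build(...)` pattern mirror this diff.

    # Hypothetical example node; everything except the imports and the context.*
    # calls is illustrative and is not part of this changeset.
    from invokeai.app.invocations.baseinvocation import BaseInvocation, WithMetadata, invocation
    from invokeai.app.invocations.fields import ImageField, InputField
    from invokeai.app.invocations.primitives import ImageOutput


    @invocation("example_grayscale", title="Example Grayscale", tags=["image"], category="image", version="1.0.0")
    class ExampleGrayscaleInvocation(BaseInvocation, WithMetadata):
        """Converts an image to grayscale (illustrative only)."""

        image: ImageField = InputField(description="The image to process")

        def invoke(self, context) -> ImageOutput:
            # Previously: context.services.images.get_pil_image(self.image.image_name)
            image = context.images.get_pil(self.image.image_name)

            grayscale_image = image.convert("L")

            # Previously: context.services.images.create(image=..., image_origin=...,
            # image_category=..., node_id=..., session_id=..., is_intermediate=...,
            # metadata=..., workflow=...)
            image_dto = context.images.save(image=grayscale_image)

            # Previously: ImageOutput(image=ImageField(image_name=...), width=..., height=...)
            return ImageOutput.build(image_dto)

Tensor-producing nodes follow the same shape: `name = context.latents.save(tensor=result_latents)` returns the storage name that nodes previously assembled from `graph_execution_state_id` and `self.id`, and the output is built with `LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)`.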
@@ -505,7 +482,7 @@ class DenoiseLatentsInvocation(BaseInvocation): if not isinstance(single_ipa_images, list): single_ipa_images = [single_ipa_images] - single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images] + single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images] # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments. @@ -532,7 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation): def run_t2i_adapters( self, - context: InvocationContext, + context: "InvocationContext", t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]], latents_shape: list[int], do_classifier_free_guidance: bool, @@ -549,13 +526,12 @@ class DenoiseLatentsInvocation(BaseInvocation): t2i_adapter_data = [] for t2i_adapter_field in t2i_adapter: - t2i_adapter_model_info = context.services.model_manager.get_model( + t2i_adapter_model_info = context.models.load( model_name=t2i_adapter_field.t2i_adapter_model.model_name, model_type=ModelType.T2IAdapter, base_model=t2i_adapter_field.t2i_adapter_model.base_model, - context=context, ) - image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name) + image = context.images.get_pil(t2i_adapter_field.image.image_name) # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1: @@ -642,30 +618,30 @@ class DenoiseLatentsInvocation(BaseInvocation): return num_inference_steps, timesteps, init_timestep - def prep_inpaint_mask(self, context, latents): + def prep_inpaint_mask(self, context: "InvocationContext", latents): if self.denoise_mask is None: return None, None - mask = context.services.latents.get(self.denoise_mask.mask_name) + mask = context.latents.get(self.denoise_mask.mask_name) mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False) if self.denoise_mask.masked_latents_name is not None: - masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name) + masked_latents = context.latents.get(self.denoise_mask.masked_latents_name) else: masked_latents = None return 1 - mask, masked_latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers seed = None noise = None if self.noise is not None: - noise = context.services.latents.get(self.noise.latents_name) + noise = context.latents.get(self.noise.latents_name) seed = self.noise.seed if self.latents is not None: - latents = context.services.latents.get(self.latents.latents_name) + latents = context.latents.get(self.latents.latents_name) if seed is None: seed = self.latents.seed @@ -691,27 +667,17 @@ class DenoiseLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - def step_callback(state: PipelineIntermediateState): - self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model) + context.util.sd_step_callback(state, self.unet.unet.base_model) def 
_lora_loader(): for lora in self.unet.loras: - lora_info = context.services.model_manager.get_model( - **lora.model_dump(exclude={"weight"}), - context=context, - ) + lora_info = context.models.load(**lora.model_dump(exclude={"weight"})) yield (lora_info.context.model, lora.weight) del lora_info return - unet_info = context.services.model_manager.get_model( - **self.unet.unet.model_dump(), - context=context, - ) + unet_info = context.models.load(**self.unet.unet.model_dump()) with ( ExitStack() as exit_stack, ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config), @@ -787,9 +753,8 @@ class DenoiseLatentsInvocation(BaseInvocation): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, result_latents) - return build_latents_output(latents_name=name, latents=result_latents, seed=seed) + name = context.latents.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed) @invocation( @@ -797,7 +762,7 @@ class DenoiseLatentsInvocation(BaseInvocation): title="Latents to Image", tags=["latents", "image", "vae", "l2i"], category="latents", - version="1.2.0", + version="1.2.1", ) class LatentsToImageInvocation(BaseInvocation, WithMetadata): """Generates an image from latents.""" @@ -814,13 +779,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> ImageOutput: + latents = context.latents.get(self.latents.latents_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae: latents = latents.to(vae.device) @@ -849,7 +811,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): vae.to(dtype=torch.float16) latents = latents.half() - if self.tiled or context.services.configuration.tiled_decode: + if self.tiled or context.config.get().tiled_decode: vae.enable_tiling() else: vae.disable_tiling() @@ -873,22 +835,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] @@ -899,7 +848,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", title="Resize Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). 
Provided dimensions are floor-divided by 8.""" @@ -921,8 +870,8 @@ class ResizeLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -940,10 +889,8 @@ class ResizeLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -951,7 +898,7 @@ class ResizeLatentsInvocation(BaseInvocation): title="Scale Latents", tags=["latents", "resize"], category="latents", - version="1.0.0", + version="1.0.1", ) class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" @@ -964,8 +911,8 @@ class ScaleLatentsInvocation(BaseInvocation): mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) # TODO: device = choose_torch_device() @@ -984,10 +931,8 @@ class ScaleLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, resized_latents) - return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) + name = context.latents.save(tensor=resized_latents) + return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @invocation( @@ -995,7 +940,7 @@ class ScaleLatentsInvocation(BaseInvocation): title="Image to Latents", tags=["latents", "image", "vae", "i2l"], category="latents", - version="1.0.0", + version="1.0.1", ) class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" @@ -1055,13 +1000,10 @@ class ImageToLatentsInvocation(BaseInvocation): return latents @torch.no_grad() - def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> LatentsOutput: + image = context.images.get_pil(self.image.image_name) - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - context=context, - ) + vae_info = context.models.load(**self.vae.vae.model_dump()) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) if image_tensor.dim() == 3: @@ -1069,10 +1011,9 @@ class ImageToLatentsInvocation(BaseInvocation): latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor) - name = 
f"{context.graph_execution_state_id}__{self.id}" latents = latents.to("cpu") - context.services.latents.save(name, latents) - return build_latents_output(latents_name=name, latents=latents, seed=None) + name = context.latents.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) @singledispatchmethod @staticmethod @@ -1092,7 +1033,7 @@ class ImageToLatentsInvocation(BaseInvocation): title="Blend Latents", tags=["latents", "blend"], category="latents", - version="1.0.0", + version="1.0.1", ) class BlendLatentsInvocation(BaseInvocation): """Blend two latents using a given alpha. Latents must have same size.""" @@ -1107,9 +1048,9 @@ class BlendLatentsInvocation(BaseInvocation): ) alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents_a = context.services.latents.get(self.latents_a.latents_name) - latents_b = context.services.latents.get(self.latents_b.latents_name) + def invoke(self, context) -> LatentsOutput: + latents_a = context.latents.get(self.latents_a.latents_name) + latents_b = context.latents.get(self.latents_b.latents_name) if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") @@ -1163,10 +1104,8 @@ class BlendLatentsInvocation(BaseInvocation): if device == torch.device("mps"): mps.empty_cache() - name = f"{context.graph_execution_state_id}__{self.id}" - # context.services.latents.set(name, resized_latents) - context.services.latents.save(name, blended_latents) - return build_latents_output(latents_name=name, latents=blended_latents) + name = context.latents.save(tensor=blended_latents) + return LatentsOutput.build(latents_name=name, latents=blended_latents) # The Crop Latents node was copied from @skunkworxdark's implementation here: @@ -1176,7 +1115,7 @@ class BlendLatentsInvocation(BaseInvocation): title="Crop Latents", tags=["latents", "crop"], category="latents", - version="1.0.0", + version="1.0.1", ) # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`. # Currently, if the class names conflict then 'GET /openapi.json' fails. @@ -1210,8 +1149,8 @@ class CropLatentsCoreInvocation(BaseInvocation): description="The height (in px) of the crop rectangle in image space. 
This value will be converted to a dimension in latent space.", ) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) x1 = self.x // LATENT_SCALE_FACTOR y1 = self.y // LATENT_SCALE_FACTOR @@ -1220,10 +1159,9 @@ class CropLatentsCoreInvocation(BaseInvocation): cropped_latents = latents[..., y1:y2, x1:x2] - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, cropped_latents) + name = context.latents.save(tensor=cropped_latents) - return build_latents_output(latents_name=name, latents=cropped_latents) + return LatentsOutput.build(latents_name=name, latents=cropped_latents) @invocation_output("ideal_size_output") diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 6ca53011f0..d2dbf04981 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -8,7 +8,7 @@ from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.fields import FieldDescriptions, InputField from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0") @@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a + self.b) @@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a - self.b) @@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.a * self.b) @@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation): a: int = InputField(default=0, description=FieldDescriptions.num_1) b: int = InputField(default=0, description=FieldDescriptions.num_2) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=int(self.a / self.b)) @@ -69,7 +69,7 @@ class RandomIntInvocation(BaseInvocation): low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) @@ -88,7 +88,7 @@ class RandomFloatInvocation(BaseInvocation): high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) - def invoke(self, 
context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: random_float = np.random.uniform(self.low, self.high) rounded_float = round(random_float, self.decimals) return FloatOutput(value=rounded_float) @@ -110,7 +110,7 @@ class FloatToIntegerInvocation(BaseInvocation): default="Nearest", description="The method to use for rounding" ) - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: if self.method == "Nearest": return IntegerOutput(value=round(self.value / self.multiple) * self.multiple) elif self.method == "Floor": @@ -128,7 +128,7 @@ class RoundInvocation(BaseInvocation): value: float = InputField(default=0, description="The float value") decimals: int = InputField(default=0, description="The number of decimal places") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=round(self.value, self.decimals)) @@ -196,7 +196,7 @@ class IntegerMathInvocation(BaseInvocation): raise ValueError("Result of exponentiation is not an integer") return v - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return IntegerOutput(value=self.a + self.b) @@ -270,7 +270,7 @@ class FloatMathInvocation(BaseInvocation): raise ValueError("Root operation resulted in a complex number") return v - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9 if self.operation == "ADD": return FloatOutput(value=self.a + self.b) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 399e217dc1..9d74abd8c1 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -5,15 +5,20 @@ from pydantic import BaseModel, ConfigDict, Field from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import ControlField -from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + MetadataField, + OutputField, + UIType, +) from invokeai.app.invocations.ip_adapter import IPAdapterModelField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField -from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.t2i_adapter import T2IAdapterField from ...version import __version__ @@ -59,7 +64,7 @@ class MetadataItemInvocation(BaseInvocation): label: str = InputField(description=FieldDescriptions.metadata_item_label) value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any) - def invoke(self, context: InvocationContext) -> MetadataItemOutput: + def invoke(self, context) -> MetadataItemOutput: return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value)) @@ -76,7 +81,7 @@ class MetadataInvocation(BaseInvocation): description=FieldDescriptions.metadata_item_polymorphic ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: if 
isinstance(self.items, MetadataItemField): # single metadata item data = {self.items.label: self.items.value} @@ -95,7 +100,7 @@ class MergeMetadataInvocation(BaseInvocation): collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: data = {} for item in self.collection: data.update(item.model_dump()) @@ -213,7 +218,7 @@ class CoreMetadataInvocation(BaseInvocation): description="The start value used for refiner denoising", ) - def invoke(self, context: InvocationContext) -> MetadataOutput: + def invoke(self, context) -> MetadataOutput: """Collects and outputs a CoreMetadata object""" return MetadataOutput( diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c710c9761b..f81e559e44 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -10,7 +10,6 @@ from ...backend.model_management import BaseModelType, ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -102,7 +101,7 @@ class LoRAModelField(BaseModel): title="Main Model", tags=["model"], category="model", - version="1.0.0", + version="1.0.1", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation): model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: + def invoke(self, context) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput): clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation): title="CLIP", ) - def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + def invoke(self, context) -> LoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" @@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation): title="CLIP 2", ) - def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + def invoke(self, context) -> SDXLLoraLoaderOutput: if self.lora is None: raise Exception("No LoRA provided") base_model = self.lora.base_model lora_name = 
self.lora.model_name - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=lora_name, model_type=ModelType.Lora, @@ -384,7 +383,7 @@ class VAEModelField(BaseModel): model_config = ConfigDict(protected_namespaces=()) -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" @@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation): title="VAE", ) - def invoke(self, context: InvocationContext) -> VAEOutput: + def invoke(self, context) -> VAEOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae - if not context.services.model_manager.model_exists( + if not context.models.exists( base_model=base_model, model_name=model_name, model_type=model_type, @@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation): seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless") seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless") - def invoke(self, context: InvocationContext) -> SeamlessModeOutput: + def invoke(self, context) -> SeamlessModeOutput: # Conditionally append 'x' and 'y' based on seamless_x and seamless_y unet = copy.deepcopy(self.unet) vae = copy.deepcopy(self.vae) @@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation): s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1) s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2) - def invoke(self, context: InvocationContext) -> UNetOutput: + def invoke(self, context) -> UNetOutput: self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2) return UNetOutput(unet=self.unet) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 2e717ac561..41641152f0 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -4,15 +4,13 @@ import torch from pydantic import field_validator -from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField -from invokeai.app.invocations.latent import LatentsField +from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -67,13 +65,13 @@ class NoiseOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) - -def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): - return NoiseOutput( - noise=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput": + return cls( + noise=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) @invocation( @@ -114,7 +112,7 @@ class NoiseInvocation(BaseInvocation): """Returns the seed modulo (SEED_MAX + 1) to ensure 
it is within the valid range.""" return v % (SEED_MAX + 1) - def invoke(self, context: InvocationContext) -> NoiseOutput: + def invoke(self, context) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, @@ -122,6 +120,5 @@ class NoiseInvocation(BaseInvocation): seed=self.seed, use_cpu=self.use_cpu, ) - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, noise) - return build_noise_output(latents_name=name, latents=noise, seed=self.seed) + name = context.latents.save(tensor=noise) + return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index b43d7eaef2..3f8e6669ab 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -37,7 +37,7 @@ from .baseinvocation import ( invocation_output, ) from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { @@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation): prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.model_dump(), ) @@ -201,7 +201,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation): # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: + def invoke(self, context) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) @@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): ) # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) if self.vae.vae.submodel != SubModelType.VaeDecoder: @@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation): description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel ) - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: + def invoke(self, context) -> ONNXModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.ONNX diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index dab9c3dc0f..bf59e87d27 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -41,7 +41,7 @@ from matplotlib.ticker import MaxNLocator from invokeai.app.invocations.primitives import FloatCollectionOutput -from .baseinvocation import BaseInvocation, 
InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField @@ -62,7 +62,7 @@ class FloatLinearRangeInvocation(BaseInvocation): description="number of values to interpolate over (including start and stop)", ) - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) return FloatCollectionOutput(collection=param_list) @@ -110,7 +110,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())] title="Step Param Easing", tags=["step", "easing"], category="step", - version="1.0.0", + version="1.0.1", ) class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" @@ -130,7 +130,7 @@ class StepParamEasingInvocation(BaseInvocation): # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") show_easing_plot: bool = InputField(default=False, description="show easing plot") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) # start_step = int(np.floor(self.num_steps * self.start_step_percent)) @@ -149,19 +149,19 @@ class StepParamEasingInvocation(BaseInvocation): postlist = list(num_poststeps * [self.post_end_value]) if log_diagnostics: - context.services.logger.debug("start_step: " + str(start_step)) - context.services.logger.debug("end_step: " + str(end_step)) - context.services.logger.debug("num_easing_steps: " + str(num_easing_steps)) - context.services.logger.debug("num_presteps: " + str(num_presteps)) - context.services.logger.debug("num_poststeps: " + str(num_poststeps)) - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("postlist size: " + str(len(postlist))) - context.services.logger.debug("prelist: " + str(prelist)) - context.services.logger.debug("postlist: " + str(postlist)) + context.logger.debug("start_step: " + str(start_step)) + context.logger.debug("end_step: " + str(end_step)) + context.logger.debug("num_easing_steps: " + str(num_easing_steps)) + context.logger.debug("num_presteps: " + str(num_presteps)) + context.logger.debug("num_poststeps: " + str(num_poststeps)) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist: " + str(prelist)) + context.logger.debug("postlist: " + str(postlist)) easing_class = EASING_FUNCTIONS_MAP[self.easing] if log_diagnostics: - context.services.logger.debug("easing class: " + str(easing_class)) + context.logger.debug("easing class: " + str(easing_class)) easing_list = [] if self.mirror: # "expected" mirroring # if number of steps is even, squeeze duration down to (number_of_steps)/2 @@ -172,7 +172,7 @@ class StepParamEasingInvocation(BaseInvocation): base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) if log_diagnostics: - context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + context.logger.debug("base easing duration: " + str(base_easing_duration)) even_num_steps = num_easing_steps % 2 == 0 # even number of steps easing_function = easing_class( start=self.start_value, @@ -184,14 +184,14 @@ class StepParamEasingInvocation(BaseInvocation): easing_val = easing_function.ease(step_index) 
base_easing_vals.append(easing_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) if even_num_steps: mirror_easing_vals = list(reversed(base_easing_vals)) else: mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) if log_diagnostics: - context.services.logger.debug("base easing vals: " + str(base_easing_vals)) - context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) + context.logger.debug("base easing vals: " + str(base_easing_vals)) + context.logger.debug("mirror easing vals: " + str(mirror_easing_vals)) easing_list = base_easing_vals + mirror_easing_vals # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely @@ -226,12 +226,12 @@ class StepParamEasingInvocation(BaseInvocation): step_val = easing_function.ease(step_index) easing_list.append(step_val) if log_diagnostics: - context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) + context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) if log_diagnostics: - context.services.logger.debug("prelist size: " + str(len(prelist))) - context.services.logger.debug("easing_list size: " + str(len(easing_list))) - context.services.logger.debug("postlist size: " + str(len(postlist))) + context.logger.debug("prelist size: " + str(len(prelist))) + context.logger.debug("easing_list size: " + str(len(easing_list))) + context.logger.debug("postlist size: " + str(len(postlist))) param_list = prelist + easing_list + postlist diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 22f03454a5..ee04345eed 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -1,16 +1,26 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) -from typing import Optional, Tuple +from typing import Optional import torch -from pydantic import BaseModel, Field -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent +from invokeai.app.invocations.fields import ( + ColorField, + ConditioningField, + DenoiseMaskField, + FieldDescriptions, + ImageField, + Input, + InputField, + LatentsField, + OutputField, + UIComponent, +) +from invokeai.app.services.images.images_common import ImageDTO from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -49,7 +59,7 @@ class BooleanInvocation(BaseInvocation): value: bool = InputField(default=False, description="The boolean value") - def invoke(self, context: InvocationContext) -> BooleanOutput: + def invoke(self, context) -> BooleanOutput: return BooleanOutput(value=self.value) @@ -65,7 +75,7 @@ class BooleanCollectionInvocation(BaseInvocation): collection: list[bool] = InputField(default=[], description="The collection of boolean values") - def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: + def invoke(self, context) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -98,7 +108,7 @@ class IntegerInvocation(BaseInvocation): value: int = InputField(default=0, description="The integer value") - def invoke(self, context: InvocationContext) -> IntegerOutput: + def invoke(self, context) -> IntegerOutput: return IntegerOutput(value=self.value) @@ -114,7 +124,7 @@ class 
IntegerCollectionInvocation(BaseInvocation): collection: list[int] = InputField(default=[], description="The collection of integer values") - def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: + def invoke(self, context) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -145,7 +155,7 @@ class FloatInvocation(BaseInvocation): value: float = InputField(default=0.0, description="The float value") - def invoke(self, context: InvocationContext) -> FloatOutput: + def invoke(self, context) -> FloatOutput: return FloatOutput(value=self.value) @@ -161,7 +171,7 @@ class FloatCollectionInvocation(BaseInvocation): collection: list[float] = InputField(default=[], description="The collection of float values") - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: + def invoke(self, context) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -192,7 +202,7 @@ class StringInvocation(BaseInvocation): value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=self.value) @@ -208,7 +218,7 @@ class StringCollectionInvocation(BaseInvocation): collection: list[str] = InputField(default=[], description="The collection of string values") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -217,18 +227,6 @@ class StringCollectionInvocation(BaseInvocation): # region Image -class ImageField(BaseModel): - """An image primitive field""" - - image_name: str = Field(description="The name of the image") - - -class BoardField(BaseModel): - """A board primitive field""" - - board_id: str = Field(description="The id of the board") - - @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -237,6 +235,14 @@ class ImageOutput(BaseInvocationOutput): width: int = OutputField(description="The width of the image in pixels") height: int = OutputField(description="The height of the image in pixels") + @classmethod + def build(cls, image_dto: ImageDTO) -> "ImageOutput": + return cls( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + @invocation_output("image_collection_output") class ImageCollectionOutput(BaseInvocationOutput): @@ -247,7 +253,7 @@ class ImageCollectionOutput(BaseInvocationOutput): ) -@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0") +@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1") class ImageInvocation( BaseInvocation, ): @@ -255,8 +261,8 @@ class ImageInvocation( image: ImageField = InputField(description="The image to load") - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) return ImageOutput( image=ImageField(image_name=self.image.image_name), @@ -277,7 +283,7 @@ class ImageCollectionInvocation(BaseInvocation): collection: list[ImageField] = InputField(description="The collection of image values") - def invoke(self, context: 
InvocationContext) -> ImageCollectionOutput: + def invoke(self, context) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.collection) @@ -286,32 +292,24 @@ class ImageCollectionInvocation(BaseInvocation): # region DenoiseMask -class DenoiseMaskField(BaseModel): - """An inpaint mask field""" - - mask_name: str = Field(description="The name of the mask image") - masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents") - - @invocation_output("denoise_mask_output") class DenoiseMaskOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run") + @classmethod + def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput": + return cls( + denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name), + ) + # endregion # region Latents -class LatentsField(BaseModel): - """A latents tensor primitive field""" - - latents_name: str = Field(description="The name of the latents") - seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") - - @invocation_output("latents_output") class LatentsOutput(BaseInvocationOutput): """Base class for nodes that output a single latents tensor""" @@ -322,6 +320,14 @@ class LatentsOutput(BaseInvocationOutput): width: int = OutputField(description=FieldDescriptions.width) height: int = OutputField(description=FieldDescriptions.height) + @classmethod + def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput": + return cls( + latents=LatentsField(latents_name=latents_name, seed=seed), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) + @invocation_output("latents_collection_output") class LatentsCollectionOutput(BaseInvocationOutput): @@ -333,17 +339,17 @@ class LatentsCollectionOutput(BaseInvocationOutput): @invocation( - "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0" + "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1" ) class LatentsInvocation(BaseInvocation): """A latents tensor primitive value""" latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection) - def invoke(self, context: InvocationContext) -> LatentsOutput: - latents = context.services.latents.get(self.latents.latents_name) + def invoke(self, context) -> LatentsOutput: + latents = context.latents.get(self.latents.latents_name) - return build_latents_output(self.latents.latents_name, latents) + return LatentsOutput.build(self.latents.latents_name, latents) @invocation( @@ -360,35 +366,15 @@ class LatentsCollectionInvocation(BaseInvocation): description="The collection of latents tensors", ) - def invoke(self, context: InvocationContext) -> LatentsCollectionOutput: + def invoke(self, context) -> LatentsCollectionOutput: return LatentsCollectionOutput(collection=self.collection) -def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name, seed=seed), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) - - # endregion # region Color -class ColorField(BaseModel): - """A color primitive field""" - - r: int = Field(ge=0, le=255, description="The red component") - g: int = Field(ge=0, le=255, 
description="The green component") - b: int = Field(ge=0, le=255, description="The blue component") - a: int = Field(ge=0, le=255, description="The alpha component") - - def tuple(self) -> Tuple[int, int, int, int]: - return (self.r, self.g, self.b, self.a) - - @invocation_output("color_output") class ColorOutput(BaseInvocationOutput): """Base class for nodes that output a single color""" @@ -411,7 +397,7 @@ class ColorInvocation(BaseInvocation): color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value") - def invoke(self, context: InvocationContext) -> ColorOutput: + def invoke(self, context) -> ColorOutput: return ColorOutput(color=self.color) @@ -420,18 +406,16 @@ class ColorInvocation(BaseInvocation): # region Conditioning -class ConditioningField(BaseModel): - """A conditioning tensor primitive value""" - - conditioning_name: str = Field(description="The name of conditioning tensor") - - @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) + @classmethod + def build(cls, conditioning_name: str) -> "ConditioningOutput": + return cls(conditioning=ConditioningField(conditioning_name=conditioning_name)) + @invocation_output("conditioning_collection_output") class ConditioningCollectionOutput(BaseInvocationOutput): @@ -454,7 +438,7 @@ class ConditioningInvocation(BaseInvocation): conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection) - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context) -> ConditioningOutput: return ConditioningOutput(conditioning=self.conditioning) @@ -473,7 +457,7 @@ class ConditioningCollectionInvocation(BaseInvocation): description="The collection of conditioning tensors", ) - def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput: + def invoke(self, context) -> ConditioningCollectionOutput: return ConditioningCollectionOutput(collection=self.collection) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 94b4a217ae..4f5ef43a56 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,7 +7,7 @@ from pydantic import field_validator from invokeai.app.invocations.primitives import StringCollectionOutput -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, UIComponent @@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation): max_prompts: int = InputField(default=1, description="The number of prompts to generate") combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: if self.combinatorial: generator = CombinatorialPromptGenerator() prompts = generator.generate(self.prompt, max_prompts=self.max_prompts) @@ -91,7 +91,7 @@ class PromptsFromFileInvocation(BaseInvocation): break return prompts - def invoke(self, context: InvocationContext) -> StringCollectionOutput: + def invoke(self, context) -> StringCollectionOutput: prompts = self.promptsFromFile( self.file_path, self.pre_prompt, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 
62df5bc804..75a526cfff 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -4,7 +4,6 @@ from ...backend.model_management import ModelType, SubModelType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -30,7 +29,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -39,13 +38,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: + def invoke(self, context) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,7 +115,7 @@ class SDXLModelLoaderInvocation(BaseInvocation): title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.0", + version="1.0.1", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" @@ -128,13 +127,13 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ) # TODO: precision? - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: + def invoke(self, context) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions - if not context.services.model_manager.model_exists( + if not context.models.exists( model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index ccbc2f6d92..a4c92d9de5 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -5,7 +5,6 @@ import re from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -33,7 +32,7 @@ class StringSplitNegInvocation(BaseInvocation): string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringPosNegOutput: + def invoke(self, context) -> StringPosNegOutput: p_string = "" n_string = "" brackets_depth = 0 @@ -77,7 +76,7 @@ class StringSplitInvocation(BaseInvocation): default="", description="Delimiter to spilt with. 
blank will split on the first whitespace" ) - def invoke(self, context: InvocationContext) -> String2Output: + def invoke(self, context) -> String2Output: result = self.string.split(self.delimiter, 1) if len(result) == 2: part1, part2 = result @@ -95,7 +94,7 @@ class StringJoinInvocation(BaseInvocation): string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) @@ -107,7 +106,7 @@ class StringJoinThreeInvocation(BaseInvocation): string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea) string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or ""))) @@ -126,7 +125,7 @@ class StringReplaceInvocation(BaseInvocation): default=False, description="Use search string as a regex expression (non regex is case insensitive)" ) - def invoke(self, context: InvocationContext) -> StringOutput: + def invoke(self, context) -> StringOutput: pattern = self.search_string or "" new_string = self.string or "" if len(pattern) > 0: diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 66ac87c37b..74a098a501 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -5,13 +5,11 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField -from invokeai.app.invocations.primitives import ImageField +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.backend.model_management.models.base import BaseModelType @@ -91,7 +89,7 @@ class T2IAdapterInvocation(BaseInvocation): validate_begin_end_step(self.begin_step_percent, self.end_step_percent) return self - def invoke(self, context: InvocationContext) -> T2IAdapterOutput: + def invoke(self, context) -> T2IAdapterOutput: return T2IAdapterOutput( t2i_adapter=T2IAdapterField( image=self.image, diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index bdc23ef6ed..dd34c3dc09 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -8,13 +8,12 @@ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, Classification, - InvocationContext, + WithMetadata, invocation, invocation_output, ) -from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from 
invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.tiles.tiles import ( calc_tiles_even_split, calc_tiles_min_overlap, @@ -58,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation): description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_with_overlap( image_height=self.image_height, image_width=self.image_width, @@ -101,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation): description="The overlap, in pixels, between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_even_split( image_height=self.image_height, image_width=self.image_width, @@ -131,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") - def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + def invoke(self, context) -> CalculateImageTilesOutput: tiles = calc_tiles_min_overlap( image_height=self.image_height, image_width=self.image_width, @@ -176,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation): tile: Tile = InputField(description="The tile to split into properties.") - def invoke(self, context: InvocationContext) -> TileToPropertiesOutput: + def invoke(self, context) -> TileToPropertiesOutput: return TileToPropertiesOutput( coords_left=self.tile.coords.left, coords_right=self.tile.coords.right, @@ -213,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation): image: ImageField = InputField(description="The tile image.") tile: Tile = InputField(description="The tile properties.") - def invoke(self, context: InvocationContext) -> PairTileImageOutput: + def invoke(self, context) -> PairTileImageOutput: return PairTileImageOutput( tile_with_image=TileWithImage( tile=self.tile, @@ -249,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context) -> ImageOutput: images = [twi.image for twi in self.tiles_with_images] tiles = [twi.tile for twi in self.tiles_with_images] @@ -265,7 +264,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): # existed in memory at an earlier point in the graph. 
tile_np_images: list[np.ndarray] = [] for image in images: - pil_image = context.services.images.get_pil_image(image.image_name) + pil_image = context.images.get_pil(image.image_name) pil_image = pil_image.convert("RGB") tile_np_images.append(np.array(pil_image)) @@ -288,18 +287,5 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): # Convert into a PIL image and save pil_image = Image.fromarray(np_image) - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 2cab279a9f..ef17480986 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -8,13 +8,13 @@ import torch from PIL import Image from pydantic import ConfigDict -from invokeai.app.invocations.primitives import ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.fields import ImageField +from invokeai.app.invocations.primitives import ImageOutput from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device -from .baseinvocation import BaseInvocation, InvocationContext, invocation +from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithMetadata # TODO: Populate this from disk? 
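The tiles.py hunk above (and the upscale.py hunks that follow) collapse the old context.services.images.create(...) boilerplate into the same two calls: save the PIL image through the new images interface, then let the output classmethod copy the name and dimensions out of the returned ImageDTO. A minimal sketch of that pattern, assuming only the API shown in these hunks (pil_image stands in for whatever image the node produced):

    # Save the image; the origin/category/session/metadata arguments of the old
    # images.create() call are no longer passed explicitly here.
    image_dto = context.images.save(image=pil_image)
    # ImageOutput.build() fills image_name, width and height from the returned ImageDTO.
    return ImageOutput.build(image_dto)
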
@@ -30,7 +30,7 @@ if choose_torch_device() == torch.device("mps"): from torch import mps -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1") class ESRGANInvocation(BaseInvocation, WithMetadata): """Upscales an image using RealESRGAN.""" @@ -42,9 +42,9 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): model_config = ConfigDict(protected_namespaces=()) - def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_name) - models_path = context.services.configuration.models_path + def invoke(self, context) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + models_path = context.config.get().models_path rrdbnet_model = None netscale = None @@ -88,7 +88,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" - context.services.logger.error(msg) + context.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") @@ -111,19 +111,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata): if choose_torch_device() == torch.device("mps"): mps.empty_cache() - image_dto = context.services.images.create( - image=pil_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) + image_dto = context.images.save(image=pil_image) - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) + return ImageOutput.build(image_dto) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index e9365f3349..ad08ae0395 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -55,7 +55,7 @@ class EventServiceBase: queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str, - node: dict, + node_id: str, source_node_id: str, progress_image: Optional[ProgressImage], step: int, @@ -70,7 +70,7 @@ class EventServiceBase: "queue_item_id": queue_item_id, "queue_batch_id": queue_batch_id, "graph_execution_state_id": graph_execution_state_id, - "node_id": node.get("id"), + "node_id": node_id, "source_node_id": source_node_id, "progress_image": progress_image.model_dump() if progress_image is not None else None, "step": step, diff --git a/invokeai/app/services/invocation_processor/invocation_processor_default.py b/invokeai/app/services/invocation_processor/invocation_processor_default.py index 54342c0da1..d2ebe235e6 100644 --- a/invokeai/app/services/invocation_processor/invocation_processor_default.py +++ b/invokeai/app/services/invocation_processor/invocation_processor_default.py @@ -5,11 +5,11 @@ from threading import BoundedSemaphore, Event, Thread from typing import Optional import invokeai.backend.util.logging as logger -from invokeai.app.invocations.baseinvocation import InvocationContext from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem from invokeai.app.services.invocation_stats.invocation_stats_common import ( GESStatsNotFoundError, ) +from invokeai.app.services.shared.invocation_context import 
InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler from ..invoker import Invoker @@ -131,16 +131,20 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # which handles a few things: # - nodes that require a value, but get it only from a connection # - referencing the invocation cache instead of executing the node - outputs = invocation.invoke_internal( - InvocationContext( - services=self.__invoker.services, - graph_execution_state_id=graph_execution_state.id, - queue_item_id=queue_item.session_queue_item_id, - queue_id=queue_item.session_queue_id, - queue_batch_id=queue_item.session_queue_batch_id, - workflow=queue_item.workflow, - ) + context_data = InvocationContextData( + invocation=invocation, + session_id=graph_id, + workflow=queue_item.workflow, + source_node_id=source_node_id, + queue_id=queue_item.session_queue_id, + queue_item_id=queue_item.session_queue_item_id, + batch_id=queue_item.session_queue_batch_id, ) + context = build_invocation_context( + services=self.__invoker.services, + context_data=context_data, + ) + outputs = invocation.invoke_internal(context=context, services=self.__invoker.services) # Check queue to see if this is canceled, and skip if so if self.__invoker.services.queue.is_canceled(graph_execution_state.id): diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index 4c2fc4c085..a9b53ae224 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -5,11 +5,12 @@ from __future__ import annotations from abc import ABC, abstractmethod from logging import Logger from pathlib import Path -from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union +from typing import Callable, List, Literal, Optional, Tuple, Union from pydantic import Field from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -21,9 +22,6 @@ from invokeai.backend.model_management import ( ) from invokeai.backend.model_management.model_cache import CacheStats -if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext - class ModelManagerServiceBase(ABC): """Responsible for managing models on disk and in memory""" @@ -49,8 +47,7 @@ class ModelManagerServiceBase(ABC): base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - node: Optional[BaseInvocation] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """Retrieve the indicated model with name and type. 
submodel can be used to get a part (such as the vae) diff --git a/invokeai/app/services/model_manager/model_manager_common.py b/invokeai/app/services/model_manager/model_manager_common.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cdb3e59a91..b641dd3f1e 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -11,6 +11,8 @@ from pydantic import Field from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_management import ( AddModelResult, BaseModelType, @@ -30,7 +32,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device from .model_manager_base import ModelManagerServiceBase if TYPE_CHECKING: - from invokeai.app.invocations.baseinvocation import InvocationContext + pass # simple implementation @@ -86,13 +88,16 @@ class ModelManagerService(ModelManagerServiceBase): ) logger.info("Model manager service initialized") + def start(self, invoker: Invoker) -> None: + self._invoker: Optional[Invoker] = invoker + def get_model( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, - context: Optional[InvocationContext] = None, + context_data: Optional[InvocationContextData] = None, ) -> ModelInfo: """ Retrieve the indicated model. submodel can be used to get a @@ -100,9 +105,9 @@ class ModelManagerService(ModelManagerServiceBase): """ # we can emit model loading events if we are executing with access to the invocation context - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -116,9 +121,9 @@ class ModelManagerService(ModelManagerServiceBase): submodel, ) - if context: + if context_data is not None: self._emit_load_event( - context=context, + context_data=context_data, model_name=model_name, base_model=base_model, model_type=model_type, @@ -263,22 +268,25 @@ class ModelManagerService(ModelManagerServiceBase): def _emit_load_event( self, - context: InvocationContext, + context_data: InvocationContextData, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None, model_info: Optional[ModelInfo] = None, ): - if context.services.queue.is_canceled(context.graph_execution_state_id): + if self._invoker is None: + return + + if self._invoker.services.queue.is_canceled(context_data.session_id): raise CanceledException() if model_info: - context.services.events.emit_model_load_completed( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_completed( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, @@ -286,11 +294,11 @@ class ModelManagerService(ModelManagerServiceBase): 
model_info=model_info, ) else: - context.services.events.emit_model_load_started( - queue_id=context.queue_id, - queue_item_id=context.queue_item_id, - queue_batch_id=context.queue_batch_id, - graph_execution_state_id=context.graph_execution_state_id, + self._invoker.services.events.emit_model_load_started( + queue_id=context_data.queue_id, + queue_item_id=context_data.queue_item_id, + queue_batch_id=context_data.batch_id, + graph_execution_state_id=context_data.session_id, model_name=model_name, base_model=base_model, model_type=model_type, diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ba05b050c5..c0699eb96b 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -13,7 +13,6 @@ from invokeai.app.invocations import * # noqa: F401 F403 from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationContext, invocation, invocation_output, ) @@ -202,7 +201,7 @@ class GraphInvocation(BaseInvocation): # TODO: figure out how to create a default here graph: "Graph" = InputField(description="The graph to run", default=None) - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: + def invoke(self, context) -> GraphInvocationOutput: """Invoke with provided services and return outputs.""" return GraphInvocationOutput() @@ -228,7 +227,7 @@ class IterateInvocation(BaseInvocation): ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) - def invoke(self, context: InvocationContext) -> IterateInvocationOutput: + def invoke(self, context) -> IterateInvocationOutput: """Produces the outputs as values""" return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection)) @@ -255,7 +254,7 @@ class CollectInvocation(BaseInvocation): description="The collection, will be provided on execution", default=[], ui_hidden=True ) - def invoke(self, context: InvocationContext) -> CollectInvocationOutput: + def invoke(self, context) -> CollectInvocationOutput: """Invoke with provided services and return outputs.""" return CollectInvocationOutput(collection=copy.copy(self.collection)) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c0aaac54f8..b68e521c73 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,8 +6,7 @@ from PIL.Image import Image from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.compel import ConditioningFieldData -from invokeai.app.invocations.fields import MetadataField, WithMetadata +from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -245,13 +244,15 @@ class ConditioningInterface: ) return name - def get(conditioning_name: str) -> Tensor: + def get(conditioning_name: str) -> ConditioningFieldData: """ Gets conditioning data by name. :param conditioning_name: The name of the conditioning data to get. 
""" - return services.latents.get(conditioning_name) + # TODO(sm): We are (ab)using the latents storage service as a general pickle storage + # service, but it is typed as returning tensors, so we need to ignore the type here. + return services.latents.get(conditioning_name) # type: ignore [return-value] self.save = save self.get = get diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 5cc3caa9ba..d83b380d95 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,25 +1,18 @@ -from typing import Protocol +from typing import TYPE_CHECKING import torch from PIL import Image -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage -from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC -from invokeai.app.services.shared.invocation_context import InvocationContextData from ...backend.model_management.models import BaseModelType from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.util.util import image_to_dataURL - -class StepCallback(Protocol): - def __call__( - self, - intermediate_state: PipelineIntermediateState, - base_model: BaseModelType, - ) -> None: - ... +if TYPE_CHECKING: + from invokeai.app.services.events.events_base import EventServiceBase + from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC + from invokeai.app.services.shared.invocation_context import InvocationContextData def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): @@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix= def stable_diffusion_step_callback( - context_data: InvocationContextData, + context_data: "InvocationContextData", intermediate_state: PipelineIntermediateState, base_model: BaseModelType, - invocation_queue: InvocationQueueABC, - events: EventServiceBase, + invocation_queue: "InvocationQueueABC", + events: "EventServiceBase", ) -> None: if invocation_queue.is_canceled(context_data.session_id): raise CanceledException