diff --git a/docs/contributing/INVOCATIONS.md b/docs/contributing/INVOCATIONS.md
index 5d9a3690ba..ce1ee9e808 100644
--- a/docs/contributing/INVOCATIONS.md
+++ b/docs/contributing/INVOCATIONS.md
@@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation):
     width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
     height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
 
-    def invoke(self, context):
+    def invoke(self, context: InvocationContext):
         pass
 ```
@@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation):
     width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
     height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         pass
 ```
@@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation):
     width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
    height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         # Load the input image as a PIL image
         image = context.images.get_pil(self.image.image_name)
diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py
index f5709b4ba3..e02291980f 100644
--- a/invokeai/app/invocations/collections.py
+++ b/invokeai/app/invocations/collections.py
@@ -5,6 +5,7 @@ import numpy as np
 from pydantic import ValidationInfo, field_validator
 
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 
 from .baseinvocation import BaseInvocation, invocation
@@ -27,7 +28,7 @@ class RangeInvocation(BaseInvocation):
             raise ValueError("stop must be greater than start")
         return v
 
-    def invoke(self, context) -> IntegerCollectionOutput:
+    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
 
 
@@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     size: int = InputField(default=1, gt=0, description="The number of values")
     step: int = InputField(default=1, description="The step of the range")
 
-    def invoke(self, context) -> IntegerCollectionOutput:
+    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(
             collection=list(range(self.start, self.start + (self.step * self.size), self.step))
         )
@@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation):
         description="The seed for the RNG (omit for random)",
     )
 
-    def invoke(self, context) -> IntegerCollectionOutput:
+    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
         rng = np.random.default_rng(self.seed)
         return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 94caf4128d..978c6dcb17 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
@@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
     UIComponent,
 )
 from invokeai.app.invocations.primitives import ConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
@@ -31,10 +32,7 @@ from .baseinvocation import (
 )
 from .model import ClipField
 
-if TYPE_CHECKING:
-    from invokeai.app.services.shared.invocation_context import InvocationContext
-
-    # unconditioned: Optional[torch.Tensor]
+# unconditioned: Optional[torch.Tensor]
 
 
 # class ConditioningAlgo(str, Enum):
@@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation):
     )
 
     @torch.no_grad()
-    def invoke(self, context) -> ConditioningOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
         text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
 
@@ -148,7 +146,7 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
     def run_clip_compel(
         self,
-        context: "InvocationContext",
+        context: InvocationContext,
         clip_field: ClipField,
         prompt: str,
         get_pooled: bool,
@@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
 
     @torch.no_grad()
-    def invoke(self, context) -> ConditioningOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
             context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
         )
@@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
 
     @torch.no_grad()
-    def invoke(self, context) -> ConditioningOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
         c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False)
 
@@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation):
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
     skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
 
-    def invoke(self, context) -> ClipSkipInvocationOutput:
+    def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
         return ClipSkipInvocationOutput(
             clip=self.clip,
diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index e993ceffde..f8bdf14117 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -28,6 +28,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
 from invokeai.backend.model_management.models.base import BaseModelType
@@ -119,7 +120,7 @@ class ControlNetInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context) -> ControlOutput:
+    def invoke(self, context: InvocationContext) -> ControlOutput:
         return ControlOutput(
             control=ControlField(
                 image=self.image,
@@ -143,7 +144,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
         # superclass just passes through image without processing
         return image
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         raw_image = context.images.get_pil(self.image.image_name)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py
index 375b18f9c5..1ebabf5e06 100644
--- a/invokeai/app/invocations/cv.py
+++ b/invokeai/app/invocations/cv.py
@@ -7,6 +7,7 @@ from PIL import Image, ImageOps
 
 from invokeai.app.invocations.fields import ImageField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithMetadata
@@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to inpaint")
     mask: ImageField = InputField(description="The mask to use when inpainting")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         mask = context.images.get_pil(self.mask.image_name)
diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py
index dad6308981..a1702d6517 100644
--- a/invokeai/app/invocations/facetools.py
+++ b/invokeai/app/invocations/facetools.py
@@ -1,7 +1,7 @@
 import math
 import re
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, TypedDict
+from typing import Optional, TypedDict
 
 import cv2
 import numpy as np
@@ -19,9 +19,7 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory
-
-if TYPE_CHECKING:
-    from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 
 @invocation_output("face_mask_output")
@@ -176,7 +174,7 @@ def prepare_faces_list(
 
 
 def generate_face_box_mask(
-    context: "InvocationContext",
+    context: InvocationContext,
     minimum_confidence: float,
     x_offset: float,
     y_offset: float,
@@ -275,7 +273,7 @@ def generate_face_box_mask(
 
 
 def extract_face(
-    context: "InvocationContext",
+    context: InvocationContext,
     image: ImageType,
     face: FaceResultData,
     padding: int,
@@ -356,7 +354,7 @@ def extract_face(
 
 
 def get_faces_list(
-    context: "InvocationContext",
+    context: InvocationContext,
     image: ImageType,
     should_chunk: bool,
     minimum_confidence: float,
@@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
         description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
     )
 
-    def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]:
+    def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
         all_faces = get_faces_list(
             context=context,
             image=image,
@@ -485,7 +483,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
 
         return face_data
 
-    def invoke(self, context) -> FaceOffOutput:
+    def invoke(self, context: InvocationContext) -> FaceOffOutput:
         image = context.images.get_pil(self.image.image_name)
         result = self.faceoff(context=context, image=image)
 
@@ -543,7 +541,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
             raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
         return v
 
-    def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult:
+    def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
         all_faces = get_faces_list(
             context=context,
             image=image,
@@ -600,7 +598,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
             mask=mask_pil,
         )
 
-    def invoke(self, context) -> FaceMaskOutput:
+    def invoke(self, context: InvocationContext) -> FaceMaskOutput:
         image = context.images.get_pil(self.image.image_name)
         result = self.facemask(context=context, image=image)
 
@@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
         description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
     )
 
-    def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType:
+    def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
         image = image.copy()
 
         all_faces = get_faces_list(
@@ -674,7 +672,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
 
         return image
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         result_image = self.faceidentifier(context=context, image=image)
diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
index 3b8b0b4b80..7b74e4d96d 100644
--- a/invokeai/app/invocations/image.py
+++ b/invokeai/app/invocations/image.py
@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
@@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation):
 
     image: ImageField = InputField(description="The image to show")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         image.show()
 
@@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata):
     mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
     color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
 
         image_dto = context.images.save(image=image)
@@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
     width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
     height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
@@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation):
         description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
     )
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         # Calculate and create new image dimensions
@@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata):
     y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
     crop: bool = InputField(default=False, description="Crop to base image dimensions")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         base_image = context.images.get_pil(self.base_image.image_name)
         image = context.images.get_pil(self.image.image_name)
         mask = None
@@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to create the mask from")
     invert: bool = InputField(default=False, description="Whether or not to invert the mask")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         image_mask = image.split()[-1]
@@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
     image1: ImageField = InputField(description="The first image to multiply")
     image2: ImageField = InputField(description="The second image to multiply")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image1 = context.images.get_pil(self.image1.image_name)
         image2 = context.images.get_pil(self.image2.image_name)
 
@@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to get the channel from")
     channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         channel_image = image.getchannel(self.channel)
@@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to convert")
     mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         converted_image = image.convert(self.mode)
@@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
     # Metadata
     blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         blur = (
@@ -338,7 +339,7 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
     def array_from_pil(self, img):
         return numpy.array(img) / 255
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         mode = image.mode
 
@@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata):
     height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata):
     )
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata):
     min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
     max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
     min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
     max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         image_arr = numpy.asarray(image, dtype=numpy.float32)
@@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
 
     image: ImageField = InputField(description="The image to check")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         logger = context.logger
 
@@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to check")
     text: str = InputField(default="InvokeAI", description="Watermark text")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         new_image = InvisibleWatermark.add_watermark(image, self.text)
         image_dto = context.images.save(image=new_image)
@@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata):
         description="Second threshold for the hysteresis procedure in Canny edge detection"
     )
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         mask = context.images.get_pil(self.image.image_name).convert("L")
 
         npimg = numpy.asarray(mask, dtype=numpy.uint8)
@@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata):
     mask1: ImageField = InputField(description="The first mask to combine")
     mask2: ImageField = InputField(description="The second image to combine")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
         mask2 = context.images.get_pil(self.mask2.image_name).convert("L")
 
@@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata):
     mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
     mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         pil_init_mask = None
         if self.mask is not None:
             pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L")
@@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description="The image to adjust")
     hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         pil_image = context.images.get_pil(self.image.image_name)
 
         # Convert image to HSV color space
@@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
     channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
     offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         pil_image = context.images.get_pil(self.image.image_name)
 
         # extract the channel and mode from the input and reference tuple
@@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
     scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
     invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         pil_image = context.images.get_pil(self.image.image_name)
 
         # extract the channel and mode from the input and reference tuple
@@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description=FieldDescriptions.image)
     board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None)
@@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
     image: ImageField = InputField(description=FieldDescriptions.image)
     board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image_dto = context.images.get_dto(self.image.image_name)
 
         image_dto = context.images.update(
diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py
index 159bdb5f7a..b007edd9e4 100644
--- a/invokeai/app/invocations/infill.py
+++ b/invokeai/app/invocations/infill.py
@@ -8,6 +8,7 @@ from PIL import Image, ImageOps
 
 from invokeai.app.invocations.fields import ColorField, ImageField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
 from invokeai.backend.image_util.lama import LaMA
@@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
         description="The color to use to infill",
     )
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
         description="The seed to use for tile generation (omit for random)",
     )
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
@@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name).convert("RGBA")
 
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@@ -213,7 +214,7 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata):
 
     image: ImageField = InputField(description="The image to infill")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         infilled = infill_lama(image.copy())
@@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata):
 
     image: ImageField = InputField(description="The image to infill")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         infilled = infill_cv2(image.copy())
diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index b836be04b5..845fcfa284 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType
 from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
 
@@ -92,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context) -> IPAdapterOutput:
+    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_info(
             self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 0127a6521e..2cc84f80a7 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -3,7 +3,7 @@ import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import TYPE_CHECKING, List, Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
 import einops
 import numpy as np
@@ -42,6 +42,7 @@ from invokeai.app.invocations.primitives import (
     LatentsOutput,
 )
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
@@ -70,9 +71,6 @@ from .baseinvocation import (
 )
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
-if TYPE_CHECKING:
-    from invokeai.app.services.shared.invocation_context import InvocationContext
-
 if choose_torch_device() == torch.device("mps"):
     from torch import mps
@@ -177,7 +175,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
 
 
 def get_scheduler(
-    context: "InvocationContext",
+    context: InvocationContext,
     scheduler_info: ModelInfo,
     scheduler_name: str,
     seed: int,
@@ -300,7 +298,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def get_conditioning_data(
         self,
-        context: "InvocationContext",
+        context: InvocationContext,
         scheduler,
         unet,
         seed,
@@ -369,7 +367,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def prep_control_data(
         self,
-        context: "InvocationContext",
+        context: InvocationContext,
         control_input: Union[ControlField, List[ControlField]],
         latents_shape: List[int],
         exit_stack: ExitStack,
@@ -442,7 +440,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def prep_ip_adapter_data(
         self,
-        context: "InvocationContext",
+        context: InvocationContext,
         ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
         conditioning_data: ConditioningData,
         exit_stack: ExitStack,
@@ -509,7 +507,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def run_t2i_adapters(
         self,
-        context: "InvocationContext",
+        context: InvocationContext,
         t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
         latents_shape: list[int],
         do_classifier_free_guidance: bool,
@@ -618,7 +616,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return num_inference_steps, timesteps, init_timestep
 
-    def prep_inpaint_mask(self, context: "InvocationContext", latents):
+    def prep_inpaint_mask(self, context: InvocationContext, latents):
         if self.denoise_mask is None:
             return None, None
diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py
index d2dbf04981..83a092be69 100644
--- a/invokeai/app/invocations/math.py
+++ b/invokeai/app/invocations/math.py
@@ -7,6 +7,7 @@ from pydantic import ValidationInfo, field_validator
 
 from invokeai.app.invocations.fields import FieldDescriptions, InputField
 from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from .baseinvocation import BaseInvocation, invocation
 
@@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=self.a + self.b)
 
 
@@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=self.a - self.b)
 
 
@@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=self.a * self.b)
 
 
@@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=int(self.a / self.b))
 
 
@@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation):
     low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
     high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=np.random.randint(self.low, self.high))
 
 
@@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation):
     high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
    decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
 
-    def invoke(self, context) -> FloatOutput:
+    def invoke(self, context: InvocationContext) -> FloatOutput:
         random_float = np.random.uniform(self.low, self.high)
         rounded_float = round(random_float, self.decimals)
         return FloatOutput(value=rounded_float)
@@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation):
         default="Nearest", description="The method to use for rounding"
     )
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         if self.method == "Nearest":
             return IntegerOutput(value=round(self.value / self.multiple) * self.multiple)
         elif self.method == "Floor":
@@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation):
     value: float = InputField(default=0, description="The float value")
     decimals: int = InputField(default=0, description="The number of decimal places")
 
-    def invoke(self, context) -> FloatOutput:
+    def invoke(self, context: InvocationContext) -> FloatOutput:
         return FloatOutput(value=round(self.value, self.decimals))
 
 
@@ -196,7 +197,7 @@ class IntegerMathInvocation(BaseInvocation):
             raise ValueError("Result of exponentiation is not an integer")
         return v
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
         if self.operation == "ADD":
             return IntegerOutput(value=self.a + self.b)
@@ -270,7 +271,7 @@ class FloatMathInvocation(BaseInvocation):
             raise ValueError("Root operation resulted in a complex number")
         return v
 
-    def invoke(self, context) -> FloatOutput:
+    def invoke(self, context: InvocationContext) -> FloatOutput:
         # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
         if self.operation == "ADD":
             return FloatOutput(value=self.a + self.b)
diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py
index 9d74abd8c1..58edfab711 100644
--- a/invokeai/app/invocations/metadata.py
+++ b/invokeai/app/invocations/metadata.py
@@ -20,6 +20,7 @@ from invokeai.app.invocations.fields import (
 from invokeai.app.invocations.ip_adapter import IPAdapterModelField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from ...version import __version__
 
@@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation):
     label: str = InputField(description=FieldDescriptions.metadata_item_label)
     value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
 
-    def invoke(self, context) -> MetadataItemOutput:
+    def invoke(self, context: InvocationContext) -> MetadataItemOutput:
         return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
 
 
@@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation):
         description=FieldDescriptions.metadata_item_polymorphic
     )
 
-    def invoke(self, context) -> MetadataOutput:
+    def invoke(self, context: InvocationContext) -> MetadataOutput:
         if isinstance(self.items, MetadataItemField):
             # single metadata item
             data = {self.items.label: self.items.value}
@@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation):
 
     collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
 
-    def invoke(self, context) -> MetadataOutput:
+    def invoke(self, context: InvocationContext) -> MetadataOutput:
         data = {}
         for item in self.collection:
             data.update(item.model_dump())
@@ -218,7 +219,7 @@ class CoreMetadataInvocation(BaseInvocation):
         description="The start value used for refiner denoising",
     )
 
-    def invoke(self, context) -> MetadataOutput:
+    def invoke(self, context: InvocationContext) -> MetadataOutput:
         """Collects and outputs a CoreMetadata object"""
 
         return MetadataOutput(
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index f81e559e44..6a1fd6d36b 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -4,6 +4,7 @@ from typing import List, Optional
 from pydantic import BaseModel, ConfigDict, Field
 
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 
 from ...backend.model_management import BaseModelType, ModelType, SubModelType
@@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation):
     model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
     # TODO: precision?
 
-    def invoke(self, context) -> ModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
@@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation):
         title="CLIP",
     )
 
-    def invoke(self, context) -> LoraLoaderOutput:
+    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
@@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
         title="CLIP 2",
     )
 
-    def invoke(self, context) -> SDXLLoraLoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
@@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation):
         title="VAE",
     )
 
-    def invoke(self, context) -> VAEOutput:
+    def invoke(self, context: InvocationContext) -> VAEOutput:
         base_model = self.vae_model.base_model
         model_name = self.vae_model.model_name
         model_type = ModelType.Vae
@@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation):
     seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
     seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
 
-    def invoke(self, context) -> SeamlessModeOutput:
+    def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
         # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
         unet = copy.deepcopy(self.unet)
         vae = copy.deepcopy(self.vae)
@@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation):
     s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
     s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
 
-    def invoke(self, context) -> UNetOutput:
+    def invoke(self, context: InvocationContext) -> UNetOutput:
         self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
         return UNetOutput(unet=self.unet)
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index 41641152f0..78f13cc52d 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -5,6 +5,7 @@ import torch
 from pydantic import field_validator
 
 from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 
 from ...backend.util.devices import choose_torch_device, torch_dtype
@@ -112,7 +113,7 @@ class NoiseInvocation(BaseInvocation):
         """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)
 
-    def invoke(self, context) -> NoiseOutput:
+    def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index a1e318a380..e7b4d3d9fc 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation):
     prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
 
-    def invoke(self, context) -> ConditioningOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.services.model_manager.get_model(
             **self.clip.tokenizer.model_dump(),
         )
@@ -201,7 +201,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     # based on
     # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
-    def invoke(self, context) -> LatentsOutput:
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
         c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
     )
     # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         if self.vae.vae.submodel != SubModelType.VaeDecoder:
@@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
         description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
     )
 
-    def invoke(self, context) -> ONNXModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.ONNX
diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py
index bf59e87d27..6845637de9 100644
--- a/invokeai/app/invocations/param_easing.py
+++ b/invokeai/app/invocations/param_easing.py
@@ -40,6 +40,7 @@ from easing_functions import (
 from matplotlib.ticker import MaxNLocator
 
 from invokeai.app.invocations.primitives import FloatCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField
@@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
         description="number of values to interpolate over (including start and stop)",
     )
 
-    def invoke(self, context) -> FloatCollectionOutput:
+    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         param_list = list(np.linspace(self.start, self.stop, self.steps))
         return FloatCollectionOutput(collection=param_list)
 
@@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation):
     # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
     show_easing_plot: bool = InputField(default=False, description="show easing plot")
 
-    def invoke(self, context) -> FloatCollectionOutput:
+    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         log_diagnostics = False
         # convert from start_step_percent to nearest step <= (steps * start_step_percent)
         # start_step = int(np.floor(self.num_steps * self.start_step_percent))
diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py
index ee04345eed..c90d3230b2 100644
--- a/invokeai/app/invocations/primitives.py
+++ b/invokeai/app/invocations/primitives.py
@@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
     UIComponent,
 )
 from invokeai.app.services.images.images_common import ImageDTO
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from .baseinvocation import (
     BaseInvocation,
@@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation):
 
     value: bool = InputField(default=False, description="The boolean value")
 
-    def invoke(self, context) -> BooleanOutput:
+    def invoke(self, context: InvocationContext) -> BooleanOutput:
         return BooleanOutput(value=self.value)
 
 
@@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation):
 
     collection: list[bool] = InputField(default=[], description="The collection of boolean values")
 
-    def invoke(self, context) -> BooleanCollectionOutput:
+    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
         return BooleanCollectionOutput(collection=self.collection)
 
 
@@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation):
 
     value: int = InputField(default=0, description="The integer value")
 
-    def invoke(self, context) -> IntegerOutput:
+    def invoke(self, context: InvocationContext) -> IntegerOutput:
         return IntegerOutput(value=self.value)
 
 
@@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation):
 
     collection: list[int] = InputField(default=[], description="The collection of integer values")
 
-    def invoke(self, context) -> IntegerCollectionOutput:
+    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=self.collection)
 
 
@@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation):
 
     value: float = InputField(default=0.0, description="The float value")
 
-    def invoke(self, context) -> FloatOutput:
+    def invoke(self, context: InvocationContext) -> FloatOutput:
         return FloatOutput(value=self.value)
 
 
@@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation):
 
     collection: list[float] = InputField(default=[], description="The collection of float values")
 
-    def invoke(self, context) -> FloatCollectionOutput:
+    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         return FloatCollectionOutput(collection=self.collection)
 
 
@@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation):
 
     value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context) -> StringOutput:
+    def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(value=self.value)
 
 
@@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation):
 
     collection: list[str] = InputField(default=[], description="The collection of string values")
 
-    def invoke(self, context) -> StringCollectionOutput:
+    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
         return StringCollectionOutput(collection=self.collection)
 
 
@@ -261,7 +262,7 @@ class ImageInvocation(
 
     image: ImageField = InputField(description="The image to load")
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
 
         return ImageOutput(
@@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation):
 
     collection: list[ImageField] = InputField(description="The collection of image values")
 
-    def invoke(self, context) -> ImageCollectionOutput:
+    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
         return ImageCollectionOutput(collection=self.collection)
 
 
@@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation):
 
     latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
 
-    def invoke(self, context) -> LatentsOutput:
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.latents.get(self.latents.latents_name)
 
         return LatentsOutput.build(self.latents.latents_name, latents)
@@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation):
         description="The collection of latents tensors",
     )
 
-    def invoke(self, context) -> LatentsCollectionOutput:
+    def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
         return LatentsCollectionOutput(collection=self.collection)
 
 
@@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation):
 
     color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
 
-    def invoke(self, context) -> ColorOutput:
+    def invoke(self, context: InvocationContext) -> ColorOutput:
         return ColorOutput(color=self.color)
 
 
@@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation):
 
     conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
 
-    def invoke(self, context) -> ConditioningOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
         return ConditioningOutput(conditioning=self.conditioning)
 
 
@@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
         description="The collection of conditioning tensors",
     )
 
-    def invoke(self, context) -> ConditioningCollectionOutput:
+    def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
         return ConditioningCollectionOutput(collection=self.collection)
diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py
index 4f5ef43a56..234743a003 100644
--- a/invokeai/app/invocations/prompt.py
+++ b/invokeai/app/invocations/prompt.py
@@ -6,6 +6,7 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPromptGenerator
 from pydantic import field_validator
 
 from invokeai.app.invocations.primitives import StringCollectionOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, UIComponent
@@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation):
     max_prompts: int = InputField(default=1, description="The number of prompts to generate")
     combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
 
-    def invoke(self, context) -> StringCollectionOutput:
+    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
         if self.combinatorial:
             generator = CombinatorialPromptGenerator()
             prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@@ -91,7 +92,7 @@ class PromptsFromFileInvocation(BaseInvocation):
                     break
         return prompts
 
-    def invoke(self, context) -> StringCollectionOutput:
+    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
         prompts = self.promptsFromFile(
             self.file_path,
             self.pre_prompt,
diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py
index 75a526cfff..8d51674a04 100644
--- a/invokeai/app/invocations/sdxl.py
+++ b/invokeai/app/invocations/sdxl.py
@@ -1,4 +1,5 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
@@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
 
-    def invoke(self, context) -> SDXLModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
@@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
 
-    def invoke(self, context) -> SDXLRefinerModelLoaderOutput:
+    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py
index a4c92d9de5..182c976cd7 100644
--- a/invokeai/app/invocations/strings.py
+++ b/invokeai/app/invocations/strings.py
@@ -2,6 +2,8 @@
 
 import re
 
+from invokeai.app.services.shared.invocation_context import InvocationContext
+
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation):
 
     string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context) -> StringPosNegOutput:
+    def invoke(self, context: InvocationContext) -> StringPosNegOutput:
         p_string = ""
         n_string = ""
         brackets_depth = 0
@@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation):
         default="", description="Delimiter to spilt with. blank will split on the first whitespace"
     )
 
-    def invoke(self, context) -> String2Output:
+    def invoke(self, context: InvocationContext) -> String2Output:
         result = self.string.split(self.delimiter, 1)
         if len(result) == 2:
             part1, part2 = result
@@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation):
     string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
     string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context) -> StringOutput:
+    def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
 
 
@@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation):
     string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea)
     string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context) -> StringOutput:
+    def invoke(self, context: InvocationContext) -> StringOutput:
         return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or "")))
 
 
@@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation):
         default=False, description="Use search string as a regex expression (non regex is case insensitive)"
     )
 
-    def invoke(self, context) -> StringOutput:
+    def invoke(self, context: InvocationContext) -> StringOutput:
         pattern = self.search_string or ""
         new_string = self.string or ""
         if len(pattern) > 0:
diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py
index 74a098a501..0f4fe66ada 100644
--- a/invokeai/app/invocations/t2i_adapter.py
+++ b/invokeai/app/invocations/t2i_adapter.py
@@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_management.models.base import BaseModelType
 
@@ -89,7 +90,7 @@ class T2IAdapterInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context) -> T2IAdapterOutput:
+    def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
         return T2IAdapterOutput(
             t2i_adapter=T2IAdapterField(
                 image=self.image,
diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py
index 0b4c472696..19ece42376 100644
--- a/invokeai/app/invocations/tiles.py
+++ b/invokeai/app/invocations/tiles.py
@@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.tiles.tiles import (
     calc_tiles_even_split,
     calc_tiles_min_overlap,
@@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
         description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
     )
 
-    def invoke(self, context) -> CalculateImageTilesOutput:
+    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
         tiles = calc_tiles_with_overlap(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
         description="The overlap, in pixels, between adjacent tiles.",
     )
 
-    def invoke(self, context) -> CalculateImageTilesOutput:
+    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
         tiles = calc_tiles_even_split(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
     tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
     min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
 
-    def invoke(self, context) -> CalculateImageTilesOutput:
+    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
         tiles = calc_tiles_min_overlap(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation):
 
     tile: Tile = InputField(description="The tile to split into properties.")
 
-    def invoke(self, context) -> TileToPropertiesOutput:
+    def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
         return TileToPropertiesOutput(
             coords_left=self.tile.coords.left,
             coords_right=self.tile.coords.right,
@@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation):
     image: ImageField = InputField(description="The tile image.")
     tile: Tile = InputField(description="The tile properties.")
 
-    def invoke(self, context) -> PairTileImageOutput:
+    def invoke(self, context: InvocationContext) -> PairTileImageOutput:
         return PairTileImageOutput(
             tile_with_image=TileWithImage(
                 tile=self.tile,
@@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
         description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
     )
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         images = [twi.image for twi in self.tiles_with_images]
         tiles = [twi.tile for twi in self.tiles_with_images]
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index ef17480986..71ef7ca3aa 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -10,6 +10,7 @@ from pydantic import ConfigDict
 
 from invokeai.app.invocations.fields import ImageField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
 from invokeai.backend.util.devices import choose_torch_device
@@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
 
     model_config = ConfigDict(protected_namespaces=())
 
-    def invoke(self, context) -> ImageOutput:
+    def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.images.get_pil(self.image.image_name)
         models_path = context.config.get().models_path
diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py
index c0699eb96b..3df230f5ee 100644
--- a/invokeai/app/services/shared/graph.py
+++ b/invokeai/app/services/shared/graph.py
@@ -17,6 +17,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation_output,
 )
 from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import uuid_string
 
 # in 3.10 this would be "from types import NoneType"
@@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation):
     # TODO: figure out how to create a default here
     graph: "Graph" = InputField(description="The graph to run", default=None)
 
-    def invoke(self, context) -> GraphInvocationOutput:
+    def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
         """Invoke with provided services and return outputs."""
         return GraphInvocationOutput()
 
@@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation):
     )
     index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
 
-    def invoke(self, context) -> IterateInvocationOutput:
+    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
         """Produces the outputs as values"""
         return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
 
@@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation):
         description="The collection, will be provided on execution", default=[], ui_hidden=True
     )
 
-    def invoke(self, context) -> CollectInvocationOutput:
+    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
         """Invoke with provided services and return outputs."""
         return CollectInvocationOutput(collection=copy.copy(self.collection))
diff --git a/tests/aa_nodes/test_nodes.py b/tests/aa_nodes/test_nodes.py
index 559457c0e1..aab3d9c7b4 100644
--- a/tests/aa_nodes/test_nodes.py
+++ b/tests/aa_nodes/test_nodes.py
@@ -8,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
 )
 from invokeai.app.invocations.fields import InputField, OutputField
 from invokeai.app.invocations.image import ImageField
+from invokeai.app.services.shared.invocation_context import InvocationContext
 
 # Define
test invocations before importing anything that uses invocations @@ -20,7 +21,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput): class ListPassThroughInvocation(BaseInvocation): collection: list[ImageField] = InputField(default=[]) - def invoke(self, context) -> ListPassThroughInvocationOutput: + def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: return ListPassThroughInvocationOutput(collection=self.collection) @@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput): class PromptTestInvocation(BaseInvocation): prompt: str = InputField(default="") - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: return PromptTestInvocationOutput(prompt=self.prompt) @invocation("test_error", version="1.0.0") class ErrorInvocation(BaseInvocation): - def invoke(self, context) -> PromptTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") @@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") prompt2: str = InputField(default="") - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation): prompt: str = InputField(default="") image: Union[ImageField, None] = InputField(default=None) - def invoke(self, context) -> ImageTestInvocationOutput: + def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) @@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput): class PromptCollectionTestInvocation(BaseInvocation): collection: list[str] = InputField() - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) @@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput): class AnyTypeTestInvocation(BaseInvocation): value: Any = InputField(default=None) - def invoke(self, context) -> AnyTypeTestInvocationOutput: + def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput: return AnyTypeTestInvocationOutput(value=self.value) @@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation): class PolymorphicStringTestInvocation(BaseInvocation): value: Union[str, list[str]] = InputField(default="") - def invoke(self, context) -> PromptCollectionTestInvocationOutput: + def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: if isinstance(self.value, str): return PromptCollectionTestInvocationOutput(collection=[self.value]) return PromptCollectionTestInvocationOutput(collection=self.value)
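
Every hunk above makes the same mechanical change: each `invoke` method's bare `context` parameter gains an explicit `context: InvocationContext` annotation, with the matching import added at the top of the file. The following is a minimal sketch of a custom node written against the annotated API; it is not part of the patch, and the `example_upcase` node, its title, and its field are hypothetical, while the import paths mirror those added above:

```python
# Hypothetical custom node illustrating the annotated signature this patch
# standardizes on; not part of the diff above.
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("example_upcase", title="Example: Upcase", version="1.0.0")
class ExampleUpcaseInvocation(BaseInvocation):
    """Upcases a string (illustrative only)."""

    string: str = InputField(default="", description="The string to upcase")

    # The explicit annotation replaces the previously untyped `context`
    # argument, so editors and type checkers can resolve services such as
    # `context.images.get_pil(...)` and `context.config.get()`, which appear
    # in the hunks above, instead of treating `context` as `Any`.
    def invoke(self, context: InvocationContext) -> StringOutput:
        return StringOutput(value=self.string.upper())
```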