fix(nodes): restore type annotations for InvocationContext

This commit is contained in:
psychedelicious 2024-02-05 17:16:35 +11:00
parent 281c334531
commit 4ce21087d3
25 changed files with 158 additions and 143 deletions

View File

@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation):
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context):
def invoke(self, context: InvocationContext):
pass
```
@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation):
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pass
```
@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation):
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the input image as a PIL image
image = context.images.get_pil(self.image.image_name)

View File

@ -5,6 +5,7 @@ import numpy as np
from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.primitives import IntegerCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from .baseinvocation import BaseInvocation, invocation
@ -27,7 +28,7 @@ class RangeInvocation(BaseInvocation):
raise ValueError("stop must be greater than start")
return v
def invoke(self, context) -> IntegerCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation):
size: int = InputField(default=1, gt=0, description="The number of values")
step: int = InputField(default=1, description="The step of the range")
def invoke(self, context) -> IntegerCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(
collection=list(range(self.start, self.start + (self.step * self.size), self.step))
)
@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation):
description="The seed for the RNG (omit for random)",
)
def invoke(self, context) -> IntegerCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
@ -31,9 +32,6 @@ from .baseinvocation import (
)
from .model import ClipField
if TYPE_CHECKING:
from invokeai.app.services.shared.invocation_context import InvocationContext
# unconditioned: Optional[torch.Tensor]
@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation):
)
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
@ -148,7 +146,7 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_compel(
self,
context: "InvocationContext",
context: InvocationContext,
clip_field: ClipField,
prompt: str,
get_pooled: bool,
@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
)
@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
@torch.no_grad()
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context) -> ClipSkipInvocationOutput:
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
return ClipSkipInvocationOutput(
clip=self.clip,

View File

@ -28,6 +28,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.model_management.models.base import BaseModelType
@ -119,7 +120,7 @@ class ControlNetInvocation(BaseInvocation):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context) -> ControlOutput:
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
@ -143,7 +144,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
# superclass just passes through image without processing
return image
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.images.get_pil(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)

View File

@ -7,6 +7,7 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithMetadata
@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to inpaint")
mask: ImageField = InputField(description="The mask to use when inpainting")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mask = context.images.get_pil(self.mask.image_name)

View File

@ -1,7 +1,7 @@
import math
import re
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypedDict
from typing import Optional, TypedDict
import cv2
import numpy as np
@ -19,8 +19,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
if TYPE_CHECKING:
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -176,7 +174,7 @@ def prepare_faces_list(
def generate_face_box_mask(
context: "InvocationContext",
context: InvocationContext,
minimum_confidence: float,
x_offset: float,
y_offset: float,
@ -275,7 +273,7 @@ def generate_face_box_mask(
def extract_face(
context: "InvocationContext",
context: InvocationContext,
image: ImageType,
face: FaceResultData,
padding: int,
@ -356,7 +354,7 @@ def extract_face(
def get_faces_list(
context: "InvocationContext",
context: InvocationContext,
image: ImageType,
should_chunk: bool,
minimum_confidence: float,
@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
)
def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]:
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
all_faces = get_faces_list(
context=context,
image=image,
@ -485,7 +483,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
return face_data
def invoke(self, context) -> FaceOffOutput:
def invoke(self, context: InvocationContext) -> FaceOffOutput:
image = context.images.get_pil(self.image.image_name)
result = self.faceoff(context=context, image=image)
@ -543,7 +541,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
return v
def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult:
def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
all_faces = get_faces_list(
context=context,
image=image,
@ -600,7 +598,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
mask=mask_pil,
)
def invoke(self, context) -> FaceMaskOutput:
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
image = context.images.get_pil(self.image.image_name)
result = self.facemask(context=context, image=image)
@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
)
def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType:
def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
image = image.copy()
all_faces = get_faces_list(
@ -674,7 +672,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
return image
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
result_image = self.faceidentifier(context=context, image=image)

View File

@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation):
image: ImageField = InputField(description="The image to show")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image.show()
@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata):
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
image_dto = context.images.save(image=image)
@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation):
description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Calculate and create new image dimensions
@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata):
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
crop: bool = InputField(default=False, description="Crop to base image dimensions")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.images.get_pil(self.base_image.image_name)
image = context.images.get_pil(self.image.image_name)
mask = None
@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to create the mask from")
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_mask = image.split()[-1]
@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
image1: ImageField = InputField(description="The first image to multiply")
image2: ImageField = InputField(description="The second image to multiply")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.images.get_pil(self.image1.image_name)
image2 = context.images.get_pil(self.image2.image_name)
@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to get the channel from")
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
channel_image = image.getchannel(self.channel)
@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to convert")
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
converted_image = image.convert(self.mode)
@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
# Metadata
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
blur = (
@ -338,7 +339,7 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
def array_from_pil(self, img):
return numpy.array(img) / 255
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
mode = image.mode
@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata):
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata):
)
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata):
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to check")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
logger = context.logger
@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to check")
text: str = InputField(default="InvokeAI", description="Watermark text")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
new_image = InvisibleWatermark.add_watermark(image, self.text)
image_dto = context.images.save(image=new_image)
@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata):
description="Second threshold for the hysteresis procedure in Canny edge detection"
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = context.images.get_pil(self.image.image_name).convert("L")
npimg = numpy.asarray(mask, dtype=numpy.uint8)
@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata):
mask1: ImageField = InputField(description="The first mask to combine")
mask2: ImageField = InputField(description="The second image to combine")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
mask2 = context.images.get_pil(self.mask2.image_name).convert("L")
@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata):
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None
if self.mask is not None:
pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L")
@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to adjust")
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# Convert image to HSV color space
@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.images.get_pil(self.image.image_name)
# extract the channel and mode from the input and reference tuple
@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description=FieldDescriptions.image)
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None)
@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description=FieldDescriptions.image)
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.get_dto(self.image.image_name)
image_dto = context.images.update(

View File

@ -8,6 +8,7 @@ from PIL import Image, ImageOps
from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
description="The color to use to infill",
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
description="The seed to use for tile generation (omit for random)",
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name).convert("RGBA")
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -213,7 +214,7 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to infill")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = infill_lama(image.copy())
@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata):
image: ImageField = InputField(description="The image to infill")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = infill_cv2(image.copy())

View File

@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
@ -92,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context) -> IPAdapterOutput:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter

View File

@ -3,7 +3,7 @@
import math
from contextlib import ExitStack
from functools import singledispatchmethod
from typing import TYPE_CHECKING, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union
import einops
import numpy as np
@ -42,6 +42,7 @@ from invokeai.app.invocations.primitives import (
LatentsOutput,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
@ -70,9 +71,6 @@ from .baseinvocation import (
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
if TYPE_CHECKING:
from invokeai.app.services.shared.invocation_context import InvocationContext
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -177,7 +175,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
def get_scheduler(
context: "InvocationContext",
context: InvocationContext,
scheduler_info: ModelInfo,
scheduler_name: str,
seed: int,
@ -300,7 +298,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def get_conditioning_data(
self,
context: "InvocationContext",
context: InvocationContext,
scheduler,
unet,
seed,
@ -369,7 +367,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: "InvocationContext",
context: InvocationContext,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@ -442,7 +440,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_ip_adapter_data(
self,
context: "InvocationContext",
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
exit_stack: ExitStack,
@ -509,7 +507,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
def run_t2i_adapters(
self,
context: "InvocationContext",
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
@ -618,7 +616,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context: "InvocationContext", latents):
def prep_inpaint_mask(self, context: InvocationContext, latents):
if self.denoise_mask is None:
return None, None

View File

@ -7,6 +7,7 @@ from pydantic import ValidationInfo, field_validator
from invokeai.app.invocations.fields import FieldDescriptions, InputField
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.a + self.b)
@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.a - self.b)
@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.a * self.b)
@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation):
a: int = InputField(default=0, description=FieldDescriptions.num_1)
b: int = InputField(default=0, description=FieldDescriptions.num_2)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=int(self.a / self.b))
@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation):
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=np.random.randint(self.low, self.high))
@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation):
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
def invoke(self, context) -> FloatOutput:
def invoke(self, context: InvocationContext) -> FloatOutput:
random_float = np.random.uniform(self.low, self.high)
rounded_float = round(random_float, self.decimals)
return FloatOutput(value=rounded_float)
@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation):
default="Nearest", description="The method to use for rounding"
)
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
if self.method == "Nearest":
return IntegerOutput(value=round(self.value / self.multiple) * self.multiple)
elif self.method == "Floor":
@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation):
value: float = InputField(default=0, description="The float value")
decimals: int = InputField(default=0, description="The number of decimal places")
def invoke(self, context) -> FloatOutput:
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(value=round(self.value, self.decimals))
@ -196,7 +197,7 @@ class IntegerMathInvocation(BaseInvocation):
raise ValueError("Result of exponentiation is not an integer")
return v
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
# Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
if self.operation == "ADD":
return IntegerOutput(value=self.a + self.b)
@ -270,7 +271,7 @@ class FloatMathInvocation(BaseInvocation):
raise ValueError("Root operation resulted in a complex number")
return v
def invoke(self, context) -> FloatOutput:
def invoke(self, context: InvocationContext) -> FloatOutput:
# Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
if self.operation == "ADD":
return FloatOutput(value=self.a + self.b)

View File

@ -20,6 +20,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from ...version import __version__
@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation):
label: str = InputField(description=FieldDescriptions.metadata_item_label)
value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
def invoke(self, context) -> MetadataItemOutput:
def invoke(self, context: InvocationContext) -> MetadataItemOutput:
return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation):
description=FieldDescriptions.metadata_item_polymorphic
)
def invoke(self, context) -> MetadataOutput:
def invoke(self, context: InvocationContext) -> MetadataOutput:
if isinstance(self.items, MetadataItemField):
# single metadata item
data = {self.items.label: self.items.value}
@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation):
collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
def invoke(self, context) -> MetadataOutput:
def invoke(self, context: InvocationContext) -> MetadataOutput:
data = {}
for item in self.collection:
data.update(item.model_dump())
@ -218,7 +219,7 @@ class CoreMetadataInvocation(BaseInvocation):
description="The start value used for refiner denoising",
)
def invoke(self, context) -> MetadataOutput:
def invoke(self, context: InvocationContext) -> MetadataOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataOutput(

View File

@ -4,6 +4,7 @@ from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from ...backend.model_management import BaseModelType, ModelType, SubModelType
@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation):
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
# TODO: precision?
def invoke(self, context) -> ModelLoaderOutput:
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation):
title="CLIP",
)
def invoke(self, context) -> LoraLoaderOutput:
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
title="CLIP 2",
)
def invoke(self, context) -> SDXLLoraLoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation):
title="VAE",
)
def invoke(self, context) -> VAEOutput:
def invoke(self, context: InvocationContext) -> VAEOutput:
base_model = self.vae_model.base_model
model_name = self.vae_model.model_name
model_type = ModelType.Vae
@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
def invoke(self, context) -> SeamlessModeOutput:
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
unet = copy.deepcopy(self.unet)
vae = copy.deepcopy(self.vae)
@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation):
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
def invoke(self, context) -> UNetOutput:
def invoke(self, context: InvocationContext) -> UNetOutput:
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
return UNetOutput(unet=self.unet)

View File

@ -5,6 +5,7 @@ import torch
from pydantic import field_validator
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from ...backend.util.devices import choose_torch_device, torch_dtype
@ -112,7 +113,7 @@ class NoiseInvocation(BaseInvocation):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context) -> NoiseOutput:
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,

View File

@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation):
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.model_dump(),
)
@ -201,7 +201,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context) -> LatentsOutput:
def invoke(self, context: InvocationContext) -> LatentsOutput:
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
)
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
if self.vae.vae.submodel != SubModelType.VaeDecoder:
@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
)
def invoke(self, context) -> ONNXModelLoaderOutput:
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.ONNX

View File

@ -40,6 +40,7 @@ from easing_functions import (
from matplotlib.ticker import MaxNLocator
from invokeai.app.invocations.primitives import FloatCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField
@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
description="number of values to interpolate over (including start and stop)",
)
def invoke(self, context) -> FloatCollectionOutput:
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(collection=param_list)
@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation):
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
show_easing_plot: bool = InputField(default=False, description="show easing plot")
def invoke(self, context) -> FloatCollectionOutput:
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
# start_step = int(np.floor(self.num_steps * self.start_step_percent))

View File

@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
UIComponent,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import (
BaseInvocation,
@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation):
value: bool = InputField(default=False, description="The boolean value")
def invoke(self, context) -> BooleanOutput:
def invoke(self, context: InvocationContext) -> BooleanOutput:
return BooleanOutput(value=self.value)
@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation):
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
def invoke(self, context) -> BooleanCollectionOutput:
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
return BooleanCollectionOutput(collection=self.collection)
@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation):
value: int = InputField(default=0, description="The integer value")
def invoke(self, context) -> IntegerOutput:
def invoke(self, context: InvocationContext) -> IntegerOutput:
return IntegerOutput(value=self.value)
@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation):
collection: list[int] = InputField(default=[], description="The collection of integer values")
def invoke(self, context) -> IntegerCollectionOutput:
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
return IntegerCollectionOutput(collection=self.collection)
@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation):
value: float = InputField(default=0.0, description="The float value")
def invoke(self, context) -> FloatOutput:
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(value=self.value)
@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation):
collection: list[float] = InputField(default=[], description="The collection of float values")
def invoke(self, context) -> FloatCollectionOutput:
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
return FloatCollectionOutput(collection=self.collection)
@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation):
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
def invoke(self, context) -> StringOutput:
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=self.value)
@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation):
collection: list[str] = InputField(default=[], description="The collection of string values")
def invoke(self, context) -> StringCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
return StringCollectionOutput(collection=self.collection)
@ -261,7 +262,7 @@ class ImageInvocation(
image: ImageField = InputField(description="The image to load")
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
return ImageOutput(
@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation):
collection: list[ImageField] = InputField(description="The collection of image values")
def invoke(self, context) -> ImageCollectionOutput:
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
return ImageCollectionOutput(collection=self.collection)
@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation):
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
def invoke(self, context) -> LatentsOutput:
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.latents.get(self.latents.latents_name)
return LatentsOutput.build(self.latents.latents_name, latents)
@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation):
description="The collection of latents tensors",
)
def invoke(self, context) -> LatentsCollectionOutput:
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
return LatentsCollectionOutput(collection=self.collection)
@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation):
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
def invoke(self, context) -> ColorOutput:
def invoke(self, context: InvocationContext) -> ColorOutput:
return ColorOutput(color=self.color)
@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation):
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
def invoke(self, context) -> ConditioningOutput:
def invoke(self, context: InvocationContext) -> ConditioningOutput:
return ConditioningOutput(conditioning=self.conditioning)
@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
description="The collection of conditioning tensors",
)
def invoke(self, context) -> ConditioningCollectionOutput:
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
return ConditioningCollectionOutput(collection=self.collection)

View File

@ -6,6 +6,7 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPrompt
from pydantic import field_validator
from invokeai.app.invocations.primitives import StringCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, UIComponent
@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation):
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
def invoke(self, context) -> StringCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@ -91,7 +92,7 @@ class PromptsFromFileInvocation(BaseInvocation):
break
return prompts
def invoke(self, context) -> StringCollectionOutput:
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
prompts = self.promptsFromFile(
self.file_path,
self.pre_prompt,

View File

@ -1,4 +1,5 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (
@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
)
# TODO: precision?
def invoke(self, context) -> SDXLModelLoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main
@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
)
# TODO: precision?
def invoke(self, context) -> SDXLRefinerModelLoaderOutput:
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Main

View File

@ -2,6 +2,8 @@
import re
from invokeai.app.services.shared.invocation_context import InvocationContext
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation):
string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
def invoke(self, context) -> StringPosNegOutput:
def invoke(self, context: InvocationContext) -> StringPosNegOutput:
p_string = ""
n_string = ""
brackets_depth = 0
@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation):
default="", description="Delimiter to spilt with. blank will split on the first whitespace"
)
def invoke(self, context) -> String2Output:
def invoke(self, context: InvocationContext) -> String2Output:
result = self.string.split(self.delimiter, 1)
if len(result) == 2:
part1, part2 = result
@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation):
string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
def invoke(self, context) -> StringOutput:
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation):
string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea)
string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
def invoke(self, context) -> StringOutput:
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or "")))
@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation):
default=False, description="Use search string as a regex expression (non regex is case insensitive)"
)
def invoke(self, context) -> StringOutput:
def invoke(self, context: InvocationContext) -> StringOutput:
pattern = self.search_string or ""
new_string = self.string or ""
if len(pattern) > 0:

View File

@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_management.models.base import BaseModelType
@ -89,7 +90,7 @@ class T2IAdapterInvocation(BaseInvocation):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context) -> T2IAdapterOutput:
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(
image=self.image,

View File

@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.tiles.tiles import (
calc_tiles_even_split,
calc_tiles_min_overlap,
@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
)
def invoke(self, context) -> CalculateImageTilesOutput:
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_with_overlap(
image_height=self.image_height,
image_width=self.image_width,
@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
description="The overlap, in pixels, between adjacent tiles.",
)
def invoke(self, context) -> CalculateImageTilesOutput:
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_even_split(
image_height=self.image_height,
image_width=self.image_width,
@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
def invoke(self, context) -> CalculateImageTilesOutput:
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
tiles = calc_tiles_min_overlap(
image_height=self.image_height,
image_width=self.image_width,
@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation):
tile: Tile = InputField(description="The tile to split into properties.")
def invoke(self, context) -> TileToPropertiesOutput:
def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
return TileToPropertiesOutput(
coords_left=self.tile.coords.left,
coords_right=self.tile.coords.right,
@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation):
image: ImageField = InputField(description="The tile image.")
tile: Tile = InputField(description="The tile properties.")
def invoke(self, context) -> PairTileImageOutput:
def invoke(self, context: InvocationContext) -> PairTileImageOutput:
return PairTileImageOutput(
tile_with_image=TileWithImage(
tile=self.tile,
@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
)
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
images = [twi.image for twi in self.tiles_with_images]
tiles = [twi.tile for twi in self.tiles_with_images]

View File

@ -10,6 +10,7 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
model_config = ConfigDict(protected_namespaces=())
def invoke(self, context) -> ImageOutput:
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
models_path = context.config.get().models_path

View File

@ -17,6 +17,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import uuid_string
# in 3.10 this would be "from types import NoneType"
@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation):
# TODO: figure out how to create a default here
graph: "Graph" = InputField(description="The graph to run", default=None)
def invoke(self, context) -> GraphInvocationOutput:
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
"""Invoke with provided services and return outputs."""
return GraphInvocationOutput()
@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation):
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
def invoke(self, context) -> IterateInvocationOutput:
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation):
description="The collection, will be provided on execution", default=[], ui_hidden=True
)
def invoke(self, context) -> CollectInvocationOutput:
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
"""Invoke with provided services and return outputs."""
return CollectInvocationOutput(collection=copy.copy(self.collection))

View File

@ -8,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
# Define test invocations before importing anything that uses invocations
@ -20,7 +21,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput):
class ListPassThroughInvocation(BaseInvocation):
collection: list[ImageField] = InputField(default=[])
def invoke(self, context) -> ListPassThroughInvocationOutput:
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
return ListPassThroughInvocationOutput(collection=self.collection)
@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput):
class PromptTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
def invoke(self, context) -> PromptTestInvocationOutput:
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
return PromptTestInvocationOutput(prompt=self.prompt)
@invocation("test_error", version="1.0.0")
class ErrorInvocation(BaseInvocation):
def invoke(self, context) -> PromptTestInvocationOutput:
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
raise Exception("This invocation is supposed to fail")
@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
prompt2: str = InputField(default="")
def invoke(self, context) -> ImageTestInvocationOutput:
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation):
prompt: str = InputField(default="")
image: Union[ImageField, None] = InputField(default=None)
def invoke(self, context) -> ImageTestInvocationOutput:
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
class PromptCollectionTestInvocation(BaseInvocation):
collection: list[str] = InputField()
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput):
class AnyTypeTestInvocation(BaseInvocation):
value: Any = InputField(default=None)
def invoke(self, context) -> AnyTypeTestInvocationOutput:
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
return AnyTypeTestInvocationOutput(value=self.value)
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
class PolymorphicStringTestInvocation(BaseInvocation):
value: Union[str, list[str]] = InputField(default="")
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
if isinstance(self.value, str):
return PromptCollectionTestInvocationOutput(collection=[self.value])
return PromptCollectionTestInvocationOutput(collection=self.value)