mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): restore type annotations for InvocationContext
This commit is contained in:
parent
281c334531
commit
4ce21087d3
@ -174,7 +174,7 @@ class ResizeInvocation(BaseInvocation):
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
|
||||
def invoke(self, context):
|
||||
def invoke(self, context: InvocationContext):
|
||||
pass
|
||||
```
|
||||
|
||||
@ -203,7 +203,7 @@ class ResizeInvocation(BaseInvocation):
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pass
|
||||
```
|
||||
|
||||
@ -229,7 +229,7 @@ class ResizeInvocation(BaseInvocation):
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the input image as a PIL image
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
@ -27,7 +28,7 @@ class RangeInvocation(BaseInvocation):
|
||||
raise ValueError("stop must be greater than start")
|
||||
return v
|
||||
|
||||
def invoke(self, context) -> IntegerCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
@ -45,7 +46,7 @@ class RangeOfSizeInvocation(BaseInvocation):
|
||||
size: int = InputField(default=1, gt=0, description="The number of values")
|
||||
step: int = InputField(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context) -> IntegerCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(
|
||||
collection=list(range(self.start, self.start + (self.step * self.size), self.step))
|
||||
)
|
||||
@ -72,6 +73,6 @@ class RandomRangeInvocation(BaseInvocation):
|
||||
description="The seed for the RNG (omit for random)",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> IntegerCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
rng = np.random.default_rng(self.seed)
|
||||
return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
@ -31,10 +32,7 @@ from .baseinvocation import (
|
||||
)
|
||||
from .model import ClipField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
@ -65,7 +63,7 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
|
||||
@ -148,7 +146,7 @@ class CompelInvocation(BaseInvocation):
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
clip_field: ClipField,
|
||||
prompt: str,
|
||||
get_pooled: bool,
|
||||
@ -288,7 +286,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
c1, c1_pooled, ec1 = self.run_clip_compel(
|
||||
context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
|
||||
)
|
||||
@ -373,7 +371,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
# TODO: if there will appear lora for refiner - write proper prefix
|
||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
|
||||
|
||||
@ -418,7 +416,7 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context) -> ClipSkipInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
return ClipSkipInvocationOutput(
|
||||
clip=self.clip,
|
||||
|
@ -28,6 +28,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
@ -119,7 +120,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context) -> ControlOutput:
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
@ -143,7 +144,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.images.get_pil(self.image.image_name)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
@ -7,6 +7,7 @@ from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithMetadata
|
||||
@ -19,7 +20,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
mask: ImageField = InputField(description="The mask to use when inpainting")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = context.images.get_pil(self.mask.image_name)
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -19,9 +19,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@invocation_output("face_mask_output")
|
||||
@ -176,7 +174,7 @@ def prepare_faces_list(
|
||||
|
||||
|
||||
def generate_face_box_mask(
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
minimum_confidence: float,
|
||||
x_offset: float,
|
||||
y_offset: float,
|
||||
@ -275,7 +273,7 @@ def generate_face_box_mask(
|
||||
|
||||
|
||||
def extract_face(
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
image: ImageType,
|
||||
face: FaceResultData,
|
||||
padding: int,
|
||||
@ -356,7 +354,7 @@ def extract_face(
|
||||
|
||||
|
||||
def get_faces_list(
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
image: ImageType,
|
||||
should_chunk: bool,
|
||||
minimum_confidence: float,
|
||||
@ -458,7 +456,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||
)
|
||||
|
||||
def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]:
|
||||
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
|
||||
all_faces = get_faces_list(
|
||||
context=context,
|
||||
image=image,
|
||||
@ -485,7 +483,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
return face_data
|
||||
|
||||
def invoke(self, context) -> FaceOffOutput:
|
||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result = self.faceoff(context=context, image=image)
|
||||
|
||||
@ -543,7 +541,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
|
||||
return v
|
||||
|
||||
def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult:
|
||||
def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
|
||||
all_faces = get_faces_list(
|
||||
context=context,
|
||||
image=image,
|
||||
@ -600,7 +598,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
mask=mask_pil,
|
||||
)
|
||||
|
||||
def invoke(self, context) -> FaceMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result = self.facemask(context=context, image=image)
|
||||
|
||||
@ -633,7 +631,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||
)
|
||||
|
||||
def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType:
|
||||
def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
|
||||
image = image.copy()
|
||||
|
||||
all_faces = get_faces_list(
|
||||
@ -674,7 +672,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
return image
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
result_image = self.faceidentifier(context=context, image=image)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
@ -34,7 +35,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
|
||||
image: ImageField = InputField(description="The image to show")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image.show()
|
||||
|
||||
@ -62,7 +63,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata):
|
||||
mode: Literal["RGB", "RGBA"] = InputField(default="RGB", description="The mode of the image")
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color of the image")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = Image.new(mode=self.mode, size=(self.width, self.height), color=self.color.tuple())
|
||||
|
||||
image_dto = context.images.save(image=image)
|
||||
@ -86,7 +87,7 @@ class ImageCropInvocation(BaseInvocation, WithMetadata):
|
||||
width: int = InputField(default=512, gt=0, description="The width of the crop rectangle")
|
||||
height: int = InputField(default=512, gt=0, description="The height of the crop rectangle")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
|
||||
@ -125,7 +126,7 @@ class CenterPadCropInvocation(BaseInvocation):
|
||||
description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Calculate and create new image dimensions
|
||||
@ -161,7 +162,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata):
|
||||
y: int = InputField(default=0, description="The top y coordinate at which to paste the image")
|
||||
crop: bool = InputField(default=False, description="Crop to base image dimensions")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.images.get_pil(self.base_image.image_name)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mask = None
|
||||
@ -201,7 +202,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to create the mask from")
|
||||
invert: bool = InputField(default=False, description="Whether or not to invert the mask")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
@ -226,7 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
|
||||
image1: ImageField = InputField(description="The first image to multiply")
|
||||
image2: ImageField = InputField(description="The second image to multiply")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.images.get_pil(self.image1.image_name)
|
||||
image2 = context.images.get_pil(self.image2.image_name)
|
||||
|
||||
@ -253,7 +254,7 @@ class ImageChannelInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to get the channel from")
|
||||
channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
@ -279,7 +280,7 @@ class ImageConvertInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to convert")
|
||||
mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
@ -304,7 +305,7 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata):
|
||||
# Metadata
|
||||
blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
@ -338,7 +339,7 @@ class UnsharpMaskInvocation(BaseInvocation, WithMetadata):
|
||||
def array_from_pil(self, img):
|
||||
return numpy.array(img) / 255
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
mode = image.mode
|
||||
|
||||
@ -401,7 +402,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata):
|
||||
height: int = InputField(default=512, gt=0, description="The height to resize to (px)")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
@ -434,7 +435,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
@ -465,7 +466,7 @@ class ImageLerpInvocation(BaseInvocation, WithMetadata):
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum output value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum output value")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
@ -492,7 +493,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
|
||||
min: int = InputField(default=0, ge=0, le=255, description="The minimum input value")
|
||||
max: int = InputField(default=255, ge=0, le=255, description="The maximum input value")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
@ -517,7 +518,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
logger = context.logger
|
||||
@ -553,7 +554,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to check")
|
||||
text: str = InputField(default="InvokeAI", description="Watermark text")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
new_image = InvisibleWatermark.add_watermark(image, self.text)
|
||||
image_dto = context.images.save(image=new_image)
|
||||
@ -579,7 +580,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata):
|
||||
description="Second threshold for the hysteresis procedure in Canny edge detection"
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = context.images.get_pil(self.image.image_name).convert("L")
|
||||
|
||||
npimg = numpy.asarray(mask, dtype=numpy.uint8)
|
||||
@ -613,7 +614,7 @@ class MaskCombineInvocation(BaseInvocation, WithMetadata):
|
||||
mask1: ImageField = InputField(description="The first mask to combine")
|
||||
mask2: ImageField = InputField(description="The second image to combine")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask1 = context.images.get_pil(self.mask1.image_name).convert("L")
|
||||
mask2 = context.images.get_pil(self.mask2.image_name).convert("L")
|
||||
|
||||
@ -642,7 +643,7 @@ class ColorCorrectInvocation(BaseInvocation, WithMetadata):
|
||||
mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction")
|
||||
mask_blur_radius: float = InputField(default=8, description="Mask blur radius")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_init_mask = None
|
||||
if self.mask is not None:
|
||||
pil_init_mask = context.images.get_pil(self.mask.image_name).convert("L")
|
||||
@ -741,7 +742,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description="The image to adjust")
|
||||
hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Convert image to HSV color space
|
||||
@ -831,7 +832,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
|
||||
channel: COLOR_CHANNELS = InputField(description="Which channel to adjust")
|
||||
offset: int = InputField(default=0, ge=-255, le=255, description="The amount to adjust the channel by")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
@ -888,7 +889,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
|
||||
scale: float = InputField(default=1.0, ge=0.0, description="The amount to scale the channel by.")
|
||||
invert_channel: bool = InputField(default=False, description="Invert the channel after scaling")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# extract the channel and mode from the input and reference tuple
|
||||
@ -931,7 +932,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
board: BoardField = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
image_dto = context.images.save(image=image, board_id=self.board.board_id if self.board else None)
|
||||
@ -953,7 +954,7 @@ class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
|
||||
image: ImageField = InputField(description=FieldDescriptions.image)
|
||||
board: Optional[BoardField] = InputField(default=None, description=FieldDescriptions.board, input=Input.Direct)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image_dto = context.images.get_dto(self.image.image_name)
|
||||
|
||||
image_dto = context.images.update(
|
||||
|
@ -8,6 +8,7 @@ from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
@ -129,7 +130,7 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
@ -155,7 +156,7 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
@ -176,7 +177,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
@ -213,7 +214,7 @@ class LaMaInfillInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
@ -229,7 +230,7 @@ class CV2InfillInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
|
||||
@ -92,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context) -> IPAdapterOutput:
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_info(
|
||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||
|
@ -3,7 +3,7 @@
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@ -42,6 +42,7 @@ from invokeai.app.invocations.primitives import (
|
||||
LatentsOutput,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
@ -70,9 +71,6 @@ from .baseinvocation import (
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
@ -177,7 +175,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
@ -300,7 +298,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
@ -369,7 +367,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
@ -442,7 +440,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
conditioning_data: ConditioningData,
|
||||
exit_stack: ExitStack,
|
||||
@ -509,7 +507,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def run_t2i_adapters(
|
||||
self,
|
||||
context: "InvocationContext",
|
||||
context: InvocationContext,
|
||||
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
||||
latents_shape: list[int],
|
||||
do_classifier_free_guidance: bool,
|
||||
@ -618,7 +616,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context: "InvocationContext", latents):
|
||||
def prep_inpaint_mask(self, context: InvocationContext, latents):
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
|
@ -7,6 +7,7 @@ from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField
|
||||
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
|
||||
@ -18,7 +19,7 @@ class AddInvocation(BaseInvocation):
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
|
||||
|
||||
@ -29,7 +30,7 @@ class SubtractInvocation(BaseInvocation):
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a - self.b)
|
||||
|
||||
|
||||
@ -40,7 +41,7 @@ class MultiplyInvocation(BaseInvocation):
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.a * self.b)
|
||||
|
||||
|
||||
@ -51,7 +52,7 @@ class DivideInvocation(BaseInvocation):
|
||||
a: int = InputField(default=0, description=FieldDescriptions.num_1)
|
||||
b: int = InputField(default=0, description=FieldDescriptions.num_2)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=int(self.a / self.b))
|
||||
|
||||
|
||||
@ -69,7 +70,7 @@ class RandomIntInvocation(BaseInvocation):
|
||||
low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
|
||||
high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=np.random.randint(self.low, self.high))
|
||||
|
||||
|
||||
@ -88,7 +89,7 @@ class RandomFloatInvocation(BaseInvocation):
|
||||
high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
|
||||
decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
|
||||
|
||||
def invoke(self, context) -> FloatOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
random_float = np.random.uniform(self.low, self.high)
|
||||
rounded_float = round(random_float, self.decimals)
|
||||
return FloatOutput(value=rounded_float)
|
||||
@ -110,7 +111,7 @@ class FloatToIntegerInvocation(BaseInvocation):
|
||||
default="Nearest", description="The method to use for rounding"
|
||||
)
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
if self.method == "Nearest":
|
||||
return IntegerOutput(value=round(self.value / self.multiple) * self.multiple)
|
||||
elif self.method == "Floor":
|
||||
@ -128,7 +129,7 @@ class RoundInvocation(BaseInvocation):
|
||||
value: float = InputField(default=0, description="The float value")
|
||||
decimals: int = InputField(default=0, description="The number of decimal places")
|
||||
|
||||
def invoke(self, context) -> FloatOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(value=round(self.value, self.decimals))
|
||||
|
||||
|
||||
@ -196,7 +197,7 @@ class IntegerMathInvocation(BaseInvocation):
|
||||
raise ValueError("Result of exponentiation is not an integer")
|
||||
return v
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
# Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
|
||||
if self.operation == "ADD":
|
||||
return IntegerOutput(value=self.a + self.b)
|
||||
@ -270,7 +271,7 @@ class FloatMathInvocation(BaseInvocation):
|
||||
raise ValueError("Root operation resulted in a complex number")
|
||||
return v
|
||||
|
||||
def invoke(self, context) -> FloatOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
# Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
|
||||
if self.operation == "ADD":
|
||||
return FloatOutput(value=self.a + self.b)
|
||||
|
@ -20,6 +20,7 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from ...version import __version__
|
||||
|
||||
@ -64,7 +65,7 @@ class MetadataItemInvocation(BaseInvocation):
|
||||
label: str = InputField(description=FieldDescriptions.metadata_item_label)
|
||||
value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
|
||||
|
||||
def invoke(self, context) -> MetadataItemOutput:
|
||||
def invoke(self, context: InvocationContext) -> MetadataItemOutput:
|
||||
return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
|
||||
|
||||
|
||||
@ -81,7 +82,7 @@ class MetadataInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.metadata_item_polymorphic
|
||||
)
|
||||
|
||||
def invoke(self, context) -> MetadataOutput:
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
if isinstance(self.items, MetadataItemField):
|
||||
# single metadata item
|
||||
data = {self.items.label: self.items.value}
|
||||
@ -100,7 +101,7 @@ class MergeMetadataInvocation(BaseInvocation):
|
||||
|
||||
collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
|
||||
|
||||
def invoke(self, context) -> MetadataOutput:
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
data = {}
|
||||
for item in self.collection:
|
||||
data.update(item.model_dump())
|
||||
@ -218,7 +219,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
||||
description="The start value used for refiner denoising",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> MetadataOutput:
|
||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataOutput(
|
||||
|
@ -4,6 +4,7 @@ from typing import List, Optional
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
@ -109,7 +110,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context) -> ModelLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
@ -221,7 +222,7 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
title="CLIP",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> LoraLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
@ -310,7 +311,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
title="CLIP 2",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> SDXLLoraLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
@ -393,7 +394,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
title="VAE",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> VAEOutput:
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
@ -448,7 +449,7 @@ class SeamlessModeInvocation(BaseInvocation):
|
||||
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
|
||||
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
|
||||
|
||||
def invoke(self, context) -> SeamlessModeOutput:
|
||||
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
|
||||
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
|
||||
unet = copy.deepcopy(self.unet)
|
||||
vae = copy.deepcopy(self.vae)
|
||||
@ -484,6 +485,6 @@ class FreeUInvocation(BaseInvocation):
|
||||
s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
|
||||
s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
|
||||
|
||||
def invoke(self, context) -> UNetOutput:
|
||||
def invoke(self, context: InvocationContext) -> UNetOutput:
|
||||
self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
|
||||
return UNetOutput(unet=self.unet)
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
@ -112,7 +113,7 @@ class NoiseInvocation(BaseInvocation):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
||||
def invoke(self, context) -> NoiseOutput:
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
noise = get_noise(
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
|
@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
)
|
||||
@ -201,7 +201,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
# based on
|
||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||
def invoke(self, context) -> LatentsOutput:
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ONNXModelLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
|
@ -40,6 +40,7 @@ from easing_functions import (
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
|
||||
from invokeai.app.invocations.primitives import FloatCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField
|
||||
@ -62,7 +63,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
||||
description="number of values to interpolate over (including start and stop)",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> FloatCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||
return FloatCollectionOutput(collection=param_list)
|
||||
|
||||
@ -130,7 +131,7 @@ class StepParamEasingInvocation(BaseInvocation):
|
||||
# alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
|
||||
show_easing_plot: bool = InputField(default=False, description="show easing plot")
|
||||
|
||||
def invoke(self, context) -> FloatCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
log_diagnostics = False
|
||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||
# start_step = int(np.floor(self.num_steps * self.start_step_percent))
|
||||
|
@ -17,6 +17,7 @@ from invokeai.app.invocations.fields import (
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -59,7 +60,7 @@ class BooleanInvocation(BaseInvocation):
|
||||
|
||||
value: bool = InputField(default=False, description="The boolean value")
|
||||
|
||||
def invoke(self, context) -> BooleanOutput:
|
||||
def invoke(self, context: InvocationContext) -> BooleanOutput:
|
||||
return BooleanOutput(value=self.value)
|
||||
|
||||
|
||||
@ -75,7 +76,7 @@ class BooleanCollectionInvocation(BaseInvocation):
|
||||
|
||||
collection: list[bool] = InputField(default=[], description="The collection of boolean values")
|
||||
|
||||
def invoke(self, context) -> BooleanCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
|
||||
return BooleanCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -108,7 +109,7 @@ class IntegerInvocation(BaseInvocation):
|
||||
|
||||
value: int = InputField(default=0, description="The integer value")
|
||||
|
||||
def invoke(self, context) -> IntegerOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
return IntegerOutput(value=self.value)
|
||||
|
||||
|
||||
@ -124,7 +125,7 @@ class IntegerCollectionInvocation(BaseInvocation):
|
||||
|
||||
collection: list[int] = InputField(default=[], description="The collection of integer values")
|
||||
|
||||
def invoke(self, context) -> IntegerCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
|
||||
return IntegerCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -155,7 +156,7 @@ class FloatInvocation(BaseInvocation):
|
||||
|
||||
value: float = InputField(default=0.0, description="The float value")
|
||||
|
||||
def invoke(self, context) -> FloatOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
return FloatOutput(value=self.value)
|
||||
|
||||
|
||||
@ -171,7 +172,7 @@ class FloatCollectionInvocation(BaseInvocation):
|
||||
|
||||
collection: list[float] = InputField(default=[], description="The collection of float values")
|
||||
|
||||
def invoke(self, context) -> FloatCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||
return FloatCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -202,7 +203,7 @@ class StringInvocation(BaseInvocation):
|
||||
|
||||
value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context) -> StringOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(value=self.value)
|
||||
|
||||
|
||||
@ -218,7 +219,7 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
|
||||
collection: list[str] = InputField(default=[], description="The collection of string values")
|
||||
|
||||
def invoke(self, context) -> StringCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
return StringCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -261,7 +262,7 @@ class ImageInvocation(
|
||||
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
@ -283,7 +284,7 @@ class ImageCollectionInvocation(BaseInvocation):
|
||||
|
||||
collection: list[ImageField] = InputField(description="The collection of image values")
|
||||
|
||||
def invoke(self, context) -> ImageCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
|
||||
return ImageCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -346,7 +347,7 @@ class LatentsInvocation(BaseInvocation):
|
||||
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context) -> LatentsOutput:
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.latents.get(self.latents.latents_name)
|
||||
|
||||
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||
@ -366,7 +367,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
description="The collection of latents tensors",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> LatentsCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -397,7 +398,7 @@ class ColorInvocation(BaseInvocation):
|
||||
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
|
||||
|
||||
def invoke(self, context) -> ColorOutput:
|
||||
def invoke(self, context: InvocationContext) -> ColorOutput:
|
||||
return ColorOutput(color=self.color)
|
||||
|
||||
|
||||
@ -438,7 +439,7 @@ class ConditioningInvocation(BaseInvocation):
|
||||
|
||||
conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
|
||||
|
||||
def invoke(self, context) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@ -457,7 +458,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
||||
description="The collection of conditioning tensors",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ConditioningCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
|
||||
return ConditioningCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
|
@ -6,6 +6,7 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPrompt
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, UIComponent
|
||||
@ -29,7 +30,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
||||
max_prompts: int = InputField(default=1, description="The number of prompts to generate")
|
||||
combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
|
||||
|
||||
def invoke(self, context) -> StringCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
if self.combinatorial:
|
||||
generator = CombinatorialPromptGenerator()
|
||||
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
|
||||
@ -91,7 +92,7 @@ class PromptsFromFileInvocation(BaseInvocation):
|
||||
break
|
||||
return prompts
|
||||
|
||||
def invoke(self, context) -> StringCollectionOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringCollectionOutput:
|
||||
prompts = self.promptsFromFile(
|
||||
self.file_path,
|
||||
self.pre_prompt,
|
||||
|
@ -1,4 +1,5 @@
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
@ -38,7 +39,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context) -> SDXLModelLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
@ -127,7 +128,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context) -> SDXLRefinerModelLoaderOutput:
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
import re
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -32,7 +34,7 @@ class StringSplitNegInvocation(BaseInvocation):
|
||||
|
||||
string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context) -> StringPosNegOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringPosNegOutput:
|
||||
p_string = ""
|
||||
n_string = ""
|
||||
brackets_depth = 0
|
||||
@ -76,7 +78,7 @@ class StringSplitInvocation(BaseInvocation):
|
||||
default="", description="Delimiter to spilt with. blank will split on the first whitespace"
|
||||
)
|
||||
|
||||
def invoke(self, context) -> String2Output:
|
||||
def invoke(self, context: InvocationContext) -> String2Output:
|
||||
result = self.string.split(self.delimiter, 1)
|
||||
if len(result) == 2:
|
||||
part1, part2 = result
|
||||
@ -94,7 +96,7 @@ class StringJoinInvocation(BaseInvocation):
|
||||
string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
|
||||
string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context) -> StringOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
|
||||
|
||||
|
||||
@ -106,7 +108,7 @@ class StringJoinThreeInvocation(BaseInvocation):
|
||||
string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea)
|
||||
string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
|
||||
|
||||
def invoke(self, context) -> StringOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or "")))
|
||||
|
||||
|
||||
@ -125,7 +127,7 @@ class StringReplaceInvocation(BaseInvocation):
|
||||
default=False, description="Use search string as a regex expression (non regex is case insensitive)"
|
||||
)
|
||||
|
||||
def invoke(self, context) -> StringOutput:
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
pattern = self.search_string or ""
|
||||
new_string = self.string or ""
|
||||
if len(pattern) > 0:
|
||||
|
@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
|
||||
|
||||
@ -89,7 +90,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context) -> T2IAdapterOutput:
|
||||
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
|
||||
return T2IAdapterOutput(
|
||||
t2i_adapter=T2IAdapterField(
|
||||
image=self.image,
|
||||
|
@ -13,6 +13,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_even_split,
|
||||
calc_tiles_min_overlap,
|
||||
@ -56,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> CalculateImageTilesOutput:
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
tiles = calc_tiles_with_overlap(
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
@ -99,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
description="The overlap, in pixels, between adjacent tiles.",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> CalculateImageTilesOutput:
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
tiles = calc_tiles_even_split(
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
@ -129,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||
tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
|
||||
min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
|
||||
|
||||
def invoke(self, context) -> CalculateImageTilesOutput:
|
||||
def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=self.image_height,
|
||||
image_width=self.image_width,
|
||||
@ -174,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation):
|
||||
|
||||
tile: Tile = InputField(description="The tile to split into properties.")
|
||||
|
||||
def invoke(self, context) -> TileToPropertiesOutput:
|
||||
def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
|
||||
return TileToPropertiesOutput(
|
||||
coords_left=self.tile.coords.left,
|
||||
coords_right=self.tile.coords.right,
|
||||
@ -211,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation):
|
||||
image: ImageField = InputField(description="The tile image.")
|
||||
tile: Tile = InputField(description="The tile properties.")
|
||||
|
||||
def invoke(self, context) -> PairTileImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> PairTileImageOutput:
|
||||
return PairTileImageOutput(
|
||||
tile_with_image=TileWithImage(
|
||||
tile=self.tile,
|
||||
@ -247,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
|
||||
)
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
images = [twi.image for twi in self.tiles_with_images]
|
||||
tiles = [twi.tile for twi in self.tiles_with_images]
|
||||
|
||||
|
@ -10,6 +10,7 @@ from pydantic import ConfigDict
|
||||
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
@ -42,7 +43,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, context) -> ImageOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
models_path = context.config.get().models_path
|
||||
|
||||
|
@ -17,6 +17,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
@ -201,7 +202,7 @@ class GraphInvocation(BaseInvocation):
|
||||
# TODO: figure out how to create a default here
|
||||
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context) -> GraphInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return GraphInvocationOutput()
|
||||
|
||||
@ -227,7 +228,7 @@ class IterateInvocation(BaseInvocation):
|
||||
)
|
||||
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
|
||||
|
||||
def invoke(self, context) -> IterateInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||
"""Produces the outputs as values"""
|
||||
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
|
||||
|
||||
@ -254,7 +255,7 @@ class CollectInvocation(BaseInvocation):
|
||||
description="The collection, will be provided on execution", default=[], ui_hidden=True
|
||||
)
|
||||
|
||||
def invoke(self, context) -> CollectInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
@ -8,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
# Define test invocations before importing anything that uses invocations
|
||||
@ -20,7 +21,7 @@ class ListPassThroughInvocationOutput(BaseInvocationOutput):
|
||||
class ListPassThroughInvocation(BaseInvocation):
|
||||
collection: list[ImageField] = InputField(default=[])
|
||||
|
||||
def invoke(self, context) -> ListPassThroughInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput:
|
||||
return ListPassThroughInvocationOutput(collection=self.collection)
|
||||
|
||||
|
||||
@ -33,13 +34,13 @@ class PromptTestInvocationOutput(BaseInvocationOutput):
|
||||
class PromptTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
|
||||
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
return PromptTestInvocationOutput(prompt=self.prompt)
|
||||
|
||||
|
||||
@invocation("test_error", version="1.0.0")
|
||||
class ErrorInvocation(BaseInvocation):
|
||||
def invoke(self, context) -> PromptTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput:
|
||||
raise Exception("This invocation is supposed to fail")
|
||||
|
||||
|
||||
@ -53,7 +54,7 @@ class TextToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
prompt2: str = InputField(default="")
|
||||
|
||||
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@ -62,7 +63,7 @@ class ImageToImageTestInvocation(BaseInvocation):
|
||||
prompt: str = InputField(default="")
|
||||
image: Union[ImageField, None] = InputField(default=None)
|
||||
|
||||
def invoke(self, context) -> ImageTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput:
|
||||
return ImageTestInvocationOutput(image=ImageField(image_name=self.id))
|
||||
|
||||
|
||||
@ -75,7 +76,7 @@ class PromptCollectionTestInvocationOutput(BaseInvocationOutput):
|
||||
class PromptCollectionTestInvocation(BaseInvocation):
|
||||
collection: list[str] = InputField()
|
||||
|
||||
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
return PromptCollectionTestInvocationOutput(collection=self.collection.copy())
|
||||
|
||||
|
||||
@ -88,7 +89,7 @@ class AnyTypeTestInvocationOutput(BaseInvocationOutput):
|
||||
class AnyTypeTestInvocation(BaseInvocation):
|
||||
value: Any = InputField(default=None)
|
||||
|
||||
def invoke(self, context) -> AnyTypeTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> AnyTypeTestInvocationOutput:
|
||||
return AnyTypeTestInvocationOutput(value=self.value)
|
||||
|
||||
|
||||
@ -96,7 +97,7 @@ class AnyTypeTestInvocation(BaseInvocation):
|
||||
class PolymorphicStringTestInvocation(BaseInvocation):
|
||||
value: Union[str, list[str]] = InputField(default="")
|
||||
|
||||
def invoke(self, context) -> PromptCollectionTestInvocationOutput:
|
||||
def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput:
|
||||
if isinstance(self.value, str):
|
||||
return PromptCollectionTestInvocationOutput(collection=[self.value])
|
||||
return PromptCollectionTestInvocationOutput(collection=self.value)
|
||||
|
Loading…
Reference in New Issue
Block a user