Merge branch 'main' into feat/controlnet-control-modes

This commit is contained in:
blessedcoolant 2023-06-15 03:18:41 +12:00
commit 6b8e88ad7f
46 changed files with 485 additions and 652 deletions

View File

@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@images_router.patch(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
@ -99,32 +97,29 @@ async def update_image(
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_origin}/{image_name}/metadata",
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@ -136,15 +131,12 @@ async def get_image_metadata(
},
)
async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get(
"/{image_origin}/{image_name}/thumbnail",
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@ -172,14 +164,13 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_origin}/{image_name}/urls",
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name
)
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,

View File

@ -196,9 +196,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
raw_image = context.services.images.get_pil_image(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -219,10 +217,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,

View File

@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
mask = context.services.images.get_pil_image(self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
else context.services.images.get_pil_image(self.control_image.image_name)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
if self.fit:
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_name)
)
# Handle invalid model parameter
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
context.services.images.get_pil_image(self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
return MaskOutput(
mask=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
image1 = context.services.images.get_pil_image(self.image1.image_name)
image2 = context.services.images.get_pil_image(self.image2.image_name)
multiply_image = ImageChops.multiply(image1, image2)
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
channel_image = image.getchannel(self.channel)
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
converted_image = image.convert(self.mode)
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius)
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = (
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -316,8 +316,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -500,10 +499,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -599,9 +595,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)

View File

@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt',
]
}
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_origin", "image_name"]}
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -40,14 +39,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@ -62,7 +59,6 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@ -71,7 +67,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@ -79,31 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__cache: Dict[Path, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
def __init__(self, output_folder: str | Path):
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
# Validate required output folders at launch
self.__validate_storage_folders()
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_origin, image_name)
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@ -117,13 +108,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_origin, image_name)
self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +124,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@ -142,20 +133,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename)
image_path = self.get_path(image_name)
if os.path.exists(image_path):
if image_path.exists():
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
if thumbnail_path.exists():
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
@ -163,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
path = self.__output_folder / image_name
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
thumbnail_name = get_thumbnail_name(image_name)
path = self.__thumbnails_folder / thumbnail_name
abspath = os.path.abspath(path)
return path
return abspath
def validate_path(self, path: str) -> bool:
def validate_path(self, path: str | Path) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __get_cache(self, image_name: str) -> PILImageType | None:
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:

View File

@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image record."""
pass
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
def get(self, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(

View File

@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
@abstractmethod
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
def get_path(self, image_name: str) -> str:
"""Gets an image's path."""
pass
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
"""Deletes an image."""
pass
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
)
self._services.files.save(
image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_origin, image_name, True
)
image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
return ImageDTO(
# Non-nullable fields
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
self._services.records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
return self._services.files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_origin, image_name)
return self._services.records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_origin, image_name)
image_record = self._services.records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
)
return image_dto
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_origin, image_name, thumbnail)
return self._services.files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_origin, r.image_name, True
),
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
),
results.items,
)
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
try:
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
self._services.files.delete(image_name)
self._services.records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
__output_folder: str | Path
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
latent_path.unlink()
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
def get_path(self, name: str) -> Path:
return self.__output_folder / name

View File

@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")

View File

@ -1,17 +1,12 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
return f"{self._base_url}/images/{image_basename}"

View File

@ -34,10 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, [
'image_name',
'image_origin',
]),
image: pick(controlNet.controlImage, ['image_name']),
},
},
};

View File

@ -25,7 +25,7 @@ export const addRequestedImageDeletionListener = () => {
effect: (action, { dispatch, getState }) => {
const { image, imageUsage } = action.payload;
const { image_name, image_origin } = image;
const { image_name } = image;
const state = getState();
const selectedImage = state.gallery.selectedImage;
@ -79,9 +79,7 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name));
// Delete from server
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
dispatch(imageDeleted({ imageName: image_name }));
},
});
};

View File

@ -20,7 +20,6 @@ export const addImageMetadataReceivedFulfilledListener = () => {
dispatch(
imageUpdated({
imageName: image.image_name,
imageOrigin: image.image_origin,
requestBody: { is_intermediate: false },
})
);

View File

@ -36,13 +36,12 @@ export const addInvocationCompleteEventListener = () => {
// This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_origin } = result.image;
const { image_name } = result.image;
// Get its metadata
dispatch(
imageMetadataReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);

View File

@ -11,12 +11,11 @@ export const addStagingAreaImageSavedListener = () => {
startAppListening({
actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload;
const { image_name } = action.payload;
dispatch(
imageUpdated({
imageName: image_name,
imageOrigin: image_origin,
requestBody: {
is_intermediate: false,
},

View File

@ -80,11 +80,10 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
allUsedImages.forEach(({ image_name, image_origin }) => {
allUsedImages.forEach(({ image_name }) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
imageOrigin: image_origin,
})
);
});

View File

@ -116,7 +116,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.image = {
image_name: baseImageDTO.image_name,
image_origin: baseImageDTO.image_origin,
};
}
@ -143,7 +142,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type
baseNode.mask = {
image_name: maskImageDTO.image_name,
image_origin: maskImageDTO.image_origin,
};
}
@ -160,7 +158,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.image.image_name,
imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId },
})
);
@ -171,7 +168,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch(
imageUpdated({
imageName: baseNode.mask.image_name,
imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId },
})
);

View File

@ -866,8 +866,7 @@ export const canvasSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {

View File

@ -59,8 +59,7 @@ export const gallerySlice = createSlice({
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url;

View File

@ -86,8 +86,7 @@ const imagesSlice = createSlice({
imagesAdapter.removeOne(state, imageName);
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
imagesAdapter.updateOne(state, {
id: image_name,

View File

@ -103,8 +103,7 @@ const nodesSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {

View File

@ -68,17 +68,15 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage;
const { image_name } = processedControlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name, image_origin } = controlImage;
const { image_name } = controlImage;
controlNetNode.image = {
image_name,
image_origin,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@ -354,7 +354,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
type: 'img_resize',
image: {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
},
is_intermediate: true,
height,
@ -392,7 +391,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
});
// Pass the image's dimensions to the `NOISE` node

View File

@ -57,8 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
image_name: initialImage.name,
image_origin: initialImage.type,
image_name: initialImage.image_name,
};
}

View File

@ -1,11 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es';
import { ImageDTO, ResourceOrigin } from 'services/api';
export type ImageNameAndOrigin = {
image_name: string;
image_origin: ResourceOrigin;
};
import { ImageDTO } from 'services/api';
export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected'

View File

@ -234,8 +234,7 @@ export const generationSlice = createSlice({
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;

View File

@ -24,6 +24,7 @@ export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
@ -86,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';

View File

@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
export type DynamicPromptInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'dynamic_prompt';
/**
* The prompt to parse with dynamicprompts
*/
prompt: string;
/**
* The number of prompts to generate
*/
max_prompts?: number;
/**
* Whether to use the combinatorial generator
*/
combinatorial?: boolean;
};

View File

@ -28,3 +28,4 @@ export type FloatLinearRangeInvocation = {
*/
steps?: number;
};

View File

@ -10,6 +10,7 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma
import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation';
import type { DynamicPromptInvocation } from './DynamicPromptInvocation';
import type { Edge } from './Edge';
import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation';
import type { GraphInvocation } from './GraphInvocation';
@ -71,9 +72,10 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation)>;
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/
edges?: Array<Edge>;
};

View File

@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput';
import type { LatentsOutput } from './LatentsOutput';
import type { MaskOutput } from './MaskOutput';
import type { NoiseOutput } from './NoiseOutput';
import type { PromptCollectionOutput } from './PromptCollectionOutput';
import type { PromptOutput } from './PromptOutput';
/**
@ -45,7 +46,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
results: Record<string, (IntCollectionOutput | FloatCollectionOutput | CompelOutput | ImageOutput | MaskOutput | ControlOutput | LatentsOutput | NoiseOutput | IntOutput | FloatOutput | PromptOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/**
* Errors raised when executing nodes
*/
@ -59,3 +60,4 @@ export type GraphExecutionState = {
*/
source_prepared_mapping: Record<string, Array<string>>;
};

View File

@ -14,10 +14,6 @@ export type ImageDTO = {
* The unique name of the image.
*/
image_name: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/**
* The URL of the image.
*/
@ -26,6 +22,10 @@ export type ImageDTO = {
* The URL of the image's thumbnail.
*/
thumbnail_url: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/**
* The category of the image.
*/

View File

@ -2,16 +2,10 @@
/* tslint:disable */
/* eslint-disable */
import type { ResourceOrigin } from './ResourceOrigin';
/**
* An image field used for passing image objects between invocations
*/
export type ImageField = {
/**
* The type of the image
*/
image_origin: ResourceOrigin;
/**
* The name of the image
*/

View File

@ -2,8 +2,6 @@
/* tslint:disable */
/* eslint-disable */
import type { ResourceOrigin } from './ResourceOrigin';
/**
* The URLs for an image and its thumbnail.
*/
@ -12,10 +10,6 @@ export type ImageUrlsDTO = {
* The unique name of the image.
*/
image_name: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/**
* The URL of the image.
*/

View File

@ -0,0 +1,19 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Base class for invocations that output a collection of prompts
*/
export type PromptCollectionOutput = {
type: 'prompt_collection_output';
/**
* The output prompt collection
*/
prompt_collection: Array<string>;
/**
* The size of the prompt collection
*/
count: number;
};

View File

@ -56,3 +56,4 @@ export type StepParamEasingInvocation = {
*/
show_easing_plot?: boolean;
};

View File

@ -22,33 +22,33 @@ export class ImagesService {
* @throws ApiError
*/
public static listImagesWithMetadata({
imageOrigin,
categories,
isIntermediate,
offset,
limit = 10,
}: {
/**
imageOrigin,
categories,
isIntermediate,
offset,
limit = 10,
}: {
/**
* The origin of images to list
*/
imageOrigin?: ResourceOrigin,
/**
imageOrigin?: ResourceOrigin,
/**
* The categories of image to include
*/
categories?: Array<ImageCategory>,
/**
categories?: Array<ImageCategory>,
/**
* Whether to list intermediate images
*/
isIntermediate?: boolean,
/**
isIntermediate?: boolean,
/**
* The page offset
*/
offset?: number,
/**
offset?: number,
/**
* The number of images per page
*/
limit?: number,
}): CancelablePromise<OffsetPaginatedResults_ImageDTO_> {
limit?: number,
}): CancelablePromise<OffsetPaginatedResults_ImageDTO_> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/',
@ -72,25 +72,25 @@ limit?: number,
* @throws ApiError
*/
public static uploadImage({
imageCategory,
isIntermediate,
formData,
sessionId,
}: {
/**
imageCategory,
isIntermediate,
formData,
sessionId,
}: {
/**
* The category of the image
*/
imageCategory: ImageCategory,
/**
imageCategory: ImageCategory,
/**
* Whether this is an intermediate image
*/
isIntermediate: boolean,
formData: Body_upload_image,
/**
isIntermediate: boolean,
formData: Body_upload_image,
/**
* The session ID associated with this upload, if any
*/
sessionId?: string,
}): CancelablePromise<ImageDTO> {
sessionId?: string,
}): CancelablePromise<ImageDTO> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/images/',
@ -115,23 +115,17 @@ sessionId?: string,
* @throws ApiError
*/
public static getImageFull({
imageOrigin,
imageName,
}: {
/**
* The type of full-resolution image file to get
*/
imageOrigin: ResourceOrigin,
/**
imageName,
}: {
/**
* The name of full-resolution image file to get
*/
imageName: string,
}): CancelablePromise<any> {
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}',
url: '/api/v1/images/{image_name}',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@ -148,23 +142,17 @@ imageName: string,
* @throws ApiError
*/
public static deleteImage({
imageOrigin,
imageName,
}: {
/**
* The origin of image to delete
*/
imageOrigin: ResourceOrigin,
/**
imageName,
}: {
/**
* The name of the image to delete
*/
imageName: string,
}): CancelablePromise<any> {
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'DELETE',
url: '/api/v1/images/{image_origin}/{image_name}',
url: '/api/v1/images/{image_name}',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@ -180,25 +168,19 @@ imageName: string,
* @throws ApiError
*/
public static updateImage({
imageOrigin,
imageName,
requestBody,
}: {
/**
* The origin of image to update
*/
imageOrigin: ResourceOrigin,
/**
imageName,
requestBody,
}: {
/**
* The name of the image to update
*/
imageName: string,
requestBody: ImageRecordChanges,
}): CancelablePromise<ImageDTO> {
imageName: string,
requestBody: ImageRecordChanges,
}): CancelablePromise<ImageDTO> {
return __request(OpenAPI, {
method: 'PATCH',
url: '/api/v1/images/{image_origin}/{image_name}',
url: '/api/v1/images/{image_name}',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
body: requestBody,
@ -216,23 +198,17 @@ requestBody: ImageRecordChanges,
* @throws ApiError
*/
public static getImageMetadata({
imageOrigin,
imageName,
}: {
/**
* The origin of image to get
*/
imageOrigin: ResourceOrigin,
/**
imageName,
}: {
/**
* The name of image to get
*/
imageName: string,
}): CancelablePromise<ImageDTO> {
imageName: string,
}): CancelablePromise<ImageDTO> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/metadata',
url: '/api/v1/images/{image_name}/metadata',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@ -248,23 +224,17 @@ imageName: string,
* @throws ApiError
*/
public static getImageThumbnail({
imageOrigin,
imageName,
}: {
/**
* The origin of thumbnail image file to get
*/
imageOrigin: ResourceOrigin,
/**
imageName,
}: {
/**
* The name of thumbnail image file to get
*/
imageName: string,
}): CancelablePromise<any> {
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/thumbnail',
url: '/api/v1/images/{image_name}/thumbnail',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {
@ -281,23 +251,17 @@ imageName: string,
* @throws ApiError
*/
public static getImageUrls({
imageOrigin,
imageName,
}: {
/**
* The origin of the image whose URL to get
*/
imageOrigin: ResourceOrigin,
/**
imageName,
}: {
/**
* The name of the image whose URL to get
*/
imageName: string,
}): CancelablePromise<ImageUrlsDTO> {
imageName: string,
}): CancelablePromise<ImageUrlsDTO> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/urls',
url: '/api/v1/images/{image_name}/urls',
path: {
'image_origin': imageOrigin,
'image_name': imageName,
},
errors: {

View File

@ -9,6 +9,7 @@ import type { ContentShuffleImageProcessorInvocation } from '../models/ContentSh
import type { ControlNetInvocation } from '../models/ControlNetInvocation';
import type { CvInpaintInvocation } from '../models/CvInpaintInvocation';
import type { DivideInvocation } from '../models/DivideInvocation';
import type { DynamicPromptInvocation } from '../models/DynamicPromptInvocation';
import type { Edge } from '../models/Edge';
import type { FloatLinearRangeInvocation } from '../models/FloatLinearRangeInvocation';
import type { Graph } from '../models/Graph';
@ -78,23 +79,23 @@ export class SessionsService {
* @throws ApiError
*/
public static listSessions({
page,
perPage = 10,
query = '',
}: {
/**
page,
perPage = 10,
query = '',
}: {
/**
* The page of results to get
*/
page?: number,
/**
page?: number,
/**
* The number of results per page
*/
perPage?: number,
/**
perPage?: number,
/**
* The query string to search for
*/
query?: string,
}): CancelablePromise<PaginatedResults_GraphExecutionState_> {
query?: string,
}): CancelablePromise<PaginatedResults_GraphExecutionState_> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/sessions/',
@ -116,10 +117,10 @@ query?: string,
* @throws ApiError
*/
public static createSession({
requestBody,
}: {
requestBody?: Graph,
}): CancelablePromise<GraphExecutionState> {
requestBody,
}: {
requestBody?: Graph,
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/sessions/',
@ -139,13 +140,13 @@ requestBody?: Graph,
* @throws ApiError
*/
public static getSession({
sessionId,
}: {
/**
sessionId,
}: {
/**
* The id of the session to get
*/
sessionId: string,
}): CancelablePromise<GraphExecutionState> {
sessionId: string,
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'GET',
url: '/api/v1/sessions/{session_id}',
@ -166,15 +167,15 @@ sessionId: string,
* @throws ApiError
*/
public static addNode({
sessionId,
requestBody,
}: {
/**
sessionId,
requestBody,
}: {
/**
* The id of the session
*/
sessionId: string,
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation),
}): CancelablePromise<string> {
sessionId: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<string> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/sessions/{session_id}/nodes',
@ -198,20 +199,20 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
* @throws ApiError
*/
public static updateNode({
sessionId,
nodePath,
requestBody,
}: {
/**
sessionId,
nodePath,
requestBody,
}: {
/**
* The id of the session
*/
sessionId: string,
/**
sessionId: string,
/**
* The path to the node in the graph
*/
nodePath: string,
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> {
nodePath: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'PUT',
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
@ -236,18 +237,18 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
* @throws ApiError
*/
public static deleteNode({
sessionId,
nodePath,
}: {
/**
sessionId,
nodePath,
}: {
/**
* The id of the session
*/
sessionId: string,
/**
sessionId: string,
/**
* The path to the node to delete
*/
nodePath: string,
}): CancelablePromise<GraphExecutionState> {
nodePath: string,
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'DELETE',
url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
@ -270,15 +271,15 @@ nodePath: string,
* @throws ApiError
*/
public static addEdge({
sessionId,
requestBody,
}: {
/**
sessionId,
requestBody,
}: {
/**
* The id of the session
*/
sessionId: string,
requestBody: Edge,
}): CancelablePromise<GraphExecutionState> {
sessionId: string,
requestBody: Edge,
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'POST',
url: '/api/v1/sessions/{session_id}/edges',
@ -302,33 +303,33 @@ requestBody: Edge,
* @throws ApiError
*/
public static deleteEdge({
sessionId,
fromNodeId,
fromField,
toNodeId,
toField,
}: {
/**
sessionId,
fromNodeId,
fromField,
toNodeId,
toField,
}: {
/**
* The id of the session
*/
sessionId: string,
/**
sessionId: string,
/**
* The id of the node the edge is coming from
*/
fromNodeId: string,
/**
fromNodeId: string,
/**
* The field of the node the edge is coming from
*/
fromField: string,
/**
fromField: string,
/**
* The id of the node the edge is going to
*/
toNodeId: string,
/**
toNodeId: string,
/**
* The field of the node the edge is going to
*/
toField: string,
}): CancelablePromise<GraphExecutionState> {
toField: string,
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'DELETE',
url: '/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
@ -354,18 +355,18 @@ toField: string,
* @throws ApiError
*/
public static invokeSession({
sessionId,
all = false,
}: {
/**
sessionId,
all = false,
}: {
/**
* The id of the session to invoke
*/
sessionId: string,
/**
sessionId: string,
/**
* Whether or not to invoke all remaining invocations
*/
all?: boolean,
}): CancelablePromise<any> {
all?: boolean,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'PUT',
url: '/api/v1/sessions/{session_id}/invoke',
@ -390,13 +391,13 @@ all?: boolean,
* @throws ApiError
*/
public static cancelSessionInvoke({
sessionId,
}: {
/**
sessionId,
}: {
/**
* The id of the session to cancel
*/
sessionId: string,
}): CancelablePromise<any> {
sessionId: string,
}): CancelablePromise<any> {
return __request(OpenAPI, {
method: 'DELETE',
url: '/api/v1/sessions/{session_id}/invoke',

View File

@ -44,6 +44,7 @@ dependencies = [
"datasets",
"diffusers[torch]~=0.17.0",
"dnspython==2.2.1",
"dynamicprompts",
"easing-functions",
"einops",
"eventlet",