Merge branch 'main' into feat/controlnet-control-modes

This commit is contained in:
blessedcoolant 2023-06-15 03:18:41 +12:00
commit 6b8e88ad7f
46 changed files with 485 additions and 652 deletions

View File

@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name) ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@images_router.patch( @images_router.patch(
"/{image_origin}/{image_name}", "/{image_name}",
operation_id="update_image", operation_id="update_image",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def update_image( async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"), image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body( image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image" description="The changes to apply to the image"
@ -99,32 +97,29 @@ async def update_image(
"""Updates an image""" """Updates an image"""
try: try:
return ApiDependencies.invoker.services.images.update( return ApiDependencies.invoker.services.images.update(image_name, image_changes)
image_origin, image_name, image_changes
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image") raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/metadata", "/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}", "/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -136,15 +131,12 @@ async def get_image_metadata(
}, },
) )
async def get_image_full( async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"), image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/thumbnail", "/{image_name}/thumbnail",
operation_id="get_image_thumbnail", operation_id="get_image_thumbnail",
response_class=Response, response_class=Response,
responses={ responses={
@ -172,14 +164,13 @@ async def get_image_full(
}, },
) )
async def get_image_thumbnail( async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/urls", "/{image_name}/urls",
operation_id="get_image_urls", operation_id="get_image_urls",
response_model=ImageUrlsDTO, response_model=ImageUrlsDTO,
) )
async def get_image_urls( async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO: ) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images.get_url( image_url = ApiDependencies.invoker.services.images.get_url(image_name)
image_origin, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,

View File

@ -196,9 +196,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image return image
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ? # image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image) processed_image = self.run_processor(raw_image)
@ -219,10 +217,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
) )
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField( processed_image_field = ImageField(image_name=image_dto.image_name)
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
return ImageOutput( return ImageOutput(
image=processed_image_field, image=processed_image_field,
# width=processed_image.width, # width=processed_image.width,

View File

@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name mask = context.services.images.get_pil_image(self.mask.image_name)
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
# Convert to cv image/mask # Convert to cv image/mask
# TODO: consider making these utility functions # TODO: consider making these utility functions
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
control_image = ( control_image = (
None if self.control_image is None None if self.control_image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.control_image.image_name)
self.control_image.image_origin, self.control_image.image_name
)
) )
# loading controlnet model # loading controlnet model
if (self.control_model is None or self.control_model==''): if (self.control_model is None or self.control_model==''):
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = ( image = (
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
) )
if self.fit: if self.fit:
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
image = ( image = (
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_name)
) )
# Handle invalid model parameter # Handle invalid model parameter
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
) )
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=self.image.image_name),
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width, width=image.width,
height=image.height, height=image.height,
) )
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
if image: if image:
image.show() image.show()
# TODO: how to handle failure? # TODO: how to handle failure?
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=self.image.image_name),
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width, width=image.width,
height=image.height, height=image.height,
) )
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_crop = Image.new( image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0) mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image( base_image = context.services.images.get_pil_image(self.base_image.image_name)
self.base_image.image_origin, self.base_image.image_name image = context.services.images.get_pil_image(self.image.image_name)
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get_pil_image( context.services.images.get_pil_image(self.mask.image_name)
self.mask.image_origin, self.mask.image_name
)
) )
) )
# TODO: probably shouldn't invert mask here... should user be required to do it? # TODO: probably shouldn't invert mask here... should user be required to do it?
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_mask = image.split()[-1] image_mask = image.split()[-1]
if self.invert: if self.invert:
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
) )
return MaskOutput( return MaskOutput(
mask=ImageField( mask=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image( image1 = context.services.images.get_pil_image(self.image1.image_name)
self.image1.image_origin, self.image1.image_name image2 = context.services.images.get_pil_image(self.image2.image_name)
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2) multiply_image = ImageChops.multiply(image1, image2)
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
channel_image = image.getchannel(self.channel) channel_image = image.getchannel(self.channel)
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
converted_image = image.convert(self.mode) converted_image = image.convert(self.mode)
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
blur = ( blur = (
ImageFilter.GaussianBlur(self.radius) ImageFilter.GaussianBlur(self.radius)
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor) width = int(image.width * self.scale_factor)
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max image_arr = image_arr * (self.max - self.min) + self.max
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = ( image_arr = (
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
infilled = tile_fill_missing( infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size image.copy(), seed=self.seed, tile_size=self.tile_size
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy()) infilled = infill_patchmatch(image.copy())
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -316,8 +316,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin, input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -500,10 +499,7 @@ class LatentsToImageInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -599,9 +595,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get( # image = context.services.images.get(
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# ) # )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
# TODO: this only really needs the vae # TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model) model_info = choose_model(context.services.model_manager, self.model)

View File

@ -2,8 +2,8 @@ from typing import Literal
from pydantic.fields import Field from pydantic.fields import Field
from .baseinvocation import BaseInvocationOutput from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput): class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt""" """Base class for invocations that output a prompt"""
@ -20,3 +20,38 @@ class PromptOutput(BaseInvocationOutput):
'prompt', 'prompt',
] ]
} }
class PromptCollectionOutput(BaseInvocationOutput):
"""Base class for invocations that output a collection of prompts"""
# fmt: off
type: Literal["prompt_collection_output"] = "prompt_collection_output"
prompt_collection: list[str] = Field(description="The output prompt collection")
count: int = Field(description="The size of the prompt collection")
# fmt: on
class Config:
schema_extra = {"required": ["type", "prompt_collection", "count"]}
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
if self.combinatorial:
generator = CombinatorialPromptGenerator()
prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
else:
generator = RandomPromptGenerator()
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
upscale=None, upscale=None,
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
upscale=(self.level, self.strength), upscale=(self.level, self.strength),
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config: class Config:
schema_extra = {"required": ["image_origin", "image_name"]} schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel): class ColorField(BaseModel):

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
@ -40,14 +39,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
@ -62,7 +59,6 @@ class ImageFileStorageBase(ABC):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
@ -71,7 +67,7 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@ -79,31 +75,26 @@ class ImageFileStorageBase(ABC):
class DiskImageFileStorage(ImageFileStorageBase): class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: str __output_folder: Path
__cache_ids: Queue # TODO: this is an incredibly naive cache __cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType] __cache: Dict[Path, PILImageType]
__max_cache_size: int __max_cache_size: int
def __init__(self, output_folder: str): def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder
self.__cache = dict() self.__cache = dict()
self.__cache_ids = Queue() self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True) self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / 'thumbnails'
# TODO: don't hard-code. get/save/delete should maybe take subpath? # Validate required output folders at launch
for image_origin in ResourceOrigin: self.__validate_storage_folders()
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get(self, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_origin, image_name) image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -117,13 +108,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_origin, image_name) self.__validate_storage_folders()
image_path = self.get_path(image_name)
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +124,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG") image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
@ -142,20 +133,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) image_path = self.get_path(image_name)
image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path): if image_path.exists():
send2trash(image_path) send2trash(image_path)
if image_path in self.__cache: if image_path in self.__cache:
del self.__cache[image_path] del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True) thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path): if thumbnail_path.exists():
send2trash(thumbnail_path) send2trash(thumbnail_path)
if thumbnail_path in self.__cache: if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path] del self.__cache[thumbnail_path]
@ -163,41 +153,33 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False path = self.__output_folder / image_name
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(image_name)
path = os.path.join( path = self.__thumbnails_folder / thumbnail_name
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path) return path
return abspath def validate_path(self, path: str | Path) -> bool:
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail.""" """Validates the path given for an image or thumbnail."""
try: path = path if isinstance(path, Path) else Path(path)
os.stat(path) return path.exists()
return True
except: def __validate_storage_folders(self) -> None:
return False """Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)
def __get_cache(self, image_name: str) -> PILImageType | None: def __get_cache(self, image_name: Path) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType): def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache: if not image_name in self.__cache:
self.__cache[image_name] = image self.__cache[image_name] = image
self.__cache_ids.put( self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size: if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get() cache_id = self.__cache_ids.get()
if cache_id in self.__cache: if cache_id in self.__cache:

View File

@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]): class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results""" """Offset-paginated results"""
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method # TODO: Implement an `update()` method
@abstractmethod @abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get(self, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update( def update(
self, self,
image_name: str, image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
"""Updates an image record.""" """Updates an image record."""
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused. # TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage. # Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get( def get(self, image_name: str) -> Union[ImageRecord, None]:
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update( def update(
self, self,
image_name: str, image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
try: try:
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None: if categories is not None:
## Convert the enum values to unique list of strings ## Convert the enum values to unique list of strings
category_strings = list( category_strings = list(map(lambda c: c.value, set(categories)))
map(lambda c: c.value, set(categories))
)
# Create the correct length of placeholders # Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings)) placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n" query_conditions += f"AND image_category IN ( {placeholders} )\n"
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count items=images, offset=offset, limit=limit, total=count
) )
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def update( def update(
self, self,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
pass pass
@abstractmethod @abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod @abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: def get_path(self, image_name: str) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_url( def get_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL.""" """Gets an image's or thumbnail's URL."""
pass pass
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str): def delete(self, image_name: str):
"""Deletes an image.""" """Deletes an image."""
pass pass
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
) )
self._services.files.save( self._services.files.save(
image_origin=image_origin,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_origin, image_name) image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(image_name, True)
image_origin, image_name, True
)
return ImageDTO( return ImageDTO(
# Non-nullable fields # Non-nullable fields
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
def update( def update(
self, self,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, image_origin, changes) self._services.records.update(image_name, changes)
return self.get_dto(image_origin, image_name) return self.get_dto(image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
raise raise
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record") self._services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_origin, image_name) return self._services.files.get(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file") self._services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_origin, image_name) return self._services.records.get(image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_origin, image_name) image_record = self._services.records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_origin, image_name), self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_origin, image_name, True), self._services.urls.get_image_url(image_name, True),
) )
return image_dto return image_dto
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try: try:
return self._services.files.get_path(image_origin, image_name, thumbnail) return self._services.files.get_path(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem validating image path") self._services.logger.error("Problem validating image path")
raise e raise e
def get_url( def get_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try: try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail) return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(r.image_origin, r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url( self._services.urls.get_image_url(r.image_name, True),
r.image_origin, r.image_name, True
),
), ),
results.items, results.items,
) )
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_origin: ResourceOrigin, image_name: str): def delete(self, image_name: str):
try: try:
self._services.files.delete(image_origin, image_name) self._services.files.delete(image_name)
self._services.records.delete(image_origin, image_name) self._services.records.delete(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise

View File

@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
@ -70,24 +69,26 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
class DiskLatentsStorage(LatentsStorageBase): class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching""" """Stores latents in a folder on disk without caching"""
__output_folder: str __output_folder: str | Path
def __init__(self, output_folder: str): def __init__(self, output_folder: str | Path):
self.__output_folder = output_folder self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
Path(output_folder).mkdir(parents=True, exist_ok=True) self.__output_folder.mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor: def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name) latent_path = self.get_path(name)
return torch.load(latent_path) return torch.load(latent_path)
def save(self, name: str, data: torch.Tensor) -> None: def save(self, name: str, data: torch.Tensor) -> None:
self.__output_folder.mkdir(parents=True, exist_ok=True)
latent_path = self.get_path(name) latent_path = self.get_path(name)
torch.save(data, latent_path) torch.save(data, latent_path)
def delete(self, name: str) -> None: def delete(self, name: str) -> None:
latent_path = self.get_path(name) latent_path = self.get_path(name)
os.remove(latent_path) latent_path.unlink()
def get_path(self, name: str) -> str: def get_path(self, name: str) -> Path:
return os.path.join(self.__output_folder, name) return self.__output_folder / name

View File

@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
"""The URL of the image.""" """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.") thumbnail_url: str = Field(description="The URL of the image's thumbnail.")

View File

@ -1,17 +1,12 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC): class UrlServiceBase(ABC):
"""Responsible for building URLs for resources.""" """Responsible for building URLs for resources."""
@abstractmethod @abstractmethod
def get_image_url( def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"): def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url self._base_url = base_url
def get_image_url( def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail: if thumbnail:
return ( return f"{self._base_url}/images/{image_basename}/thumbnail"
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_origin.value}/{image_basename}" return f"{self._base_url}/images/{image_basename}"

View File

@ -34,10 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: { [controlNet.processorNode.id]: {
...controlNet.processorNode, ...controlNet.processorNode,
is_intermediate: true, is_intermediate: true,
image: pick(controlNet.controlImage, [ image: pick(controlNet.controlImage, ['image_name']),
'image_name',
'image_origin',
]),
}, },
}, },
}; };

View File

@ -25,7 +25,7 @@ export const addRequestedImageDeletionListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { image, imageUsage } = action.payload; const { image, imageUsage } = action.payload;
const { image_name, image_origin } = image; const { image_name } = image;
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage; const selectedImage = state.gallery.selectedImage;
@ -79,9 +79,7 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name)); dispatch(imageRemoved(image_name));
// Delete from server // Delete from server
dispatch( dispatch(imageDeleted({ imageName: image_name }));
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
}, },
}); });
}; };

View File

@ -20,7 +20,6 @@ export const addImageMetadataReceivedFulfilledListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: image.image_name, imageName: image.image_name,
imageOrigin: image.image_origin,
requestBody: { is_intermediate: false }, requestBody: { is_intermediate: false },
}) })
); );

View File

@ -36,13 +36,12 @@ export const addInvocationCompleteEventListener = () => {
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_origin } = result.image; const { image_name } = result.image;
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageMetadataReceived({
imageName: image_name, imageName: image_name,
imageOrigin: image_origin,
}) })
); );

View File

@ -11,12 +11,11 @@ export const addStagingAreaImageSavedListener = () => {
startAppListening({ startAppListening({
actionCreator: stagingAreaImageSaved, actionCreator: stagingAreaImageSaved,
effect: async (action, { dispatch, getState, take }) => { effect: async (action, { dispatch, getState, take }) => {
const { image_name, image_origin } = action.payload; const { image_name } = action.payload;
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: image_name, imageName: image_name,
imageOrigin: image_origin,
requestBody: { requestBody: {
is_intermediate: false, is_intermediate: false,
}, },

View File

@ -80,11 +80,10 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images` `Fetching new image URLs for ${allUsedImages.length} images`
); );
allUsedImages.forEach(({ image_name, image_origin }) => { allUsedImages.forEach(({ image_name }) => {
dispatch( dispatch(
imageUrlsReceived({ imageUrlsReceived({
imageName: image_name, imageName: image_name,
imageOrigin: image_origin,
}) })
); );
}); });

View File

@ -116,7 +116,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.image = { baseNode.image = {
image_name: baseImageDTO.image_name, image_name: baseImageDTO.image_name,
image_origin: baseImageDTO.image_origin,
}; };
} }
@ -143,7 +142,6 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.mask = { baseNode.mask = {
image_name: maskImageDTO.image_name, image_name: maskImageDTO.image_name,
image_origin: maskImageDTO.image_origin,
}; };
} }
@ -160,7 +158,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.image.image_name, imageName: baseNode.image.image_name,
imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );
@ -171,7 +168,6 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.mask.image_name, imageName: baseNode.mask.image_name,
imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );

View File

@ -866,8 +866,7 @@ export const canvasSlice = createSlice({
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_url, thumbnail_url } = action.payload;
action.payload;
state.layerState.objects.forEach((object) => { state.layerState.objects.forEach((object) => {
if (object.kind === 'image') { if (object.kind === 'image') {

View File

@ -59,8 +59,7 @@ export const gallerySlice = createSlice({
} }
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_url, thumbnail_url } = action.payload;
action.payload;
if (state.selectedImage?.image_name === image_name) { if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url; state.selectedImage.image_url = image_url;

View File

@ -86,8 +86,7 @@ const imagesSlice = createSlice({
imagesAdapter.removeOne(state, imageName); imagesAdapter.removeOne(state, imageName);
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_url, thumbnail_url } = action.payload;
action.payload;
imagesAdapter.updateOne(state, { imagesAdapter.updateOne(state, {
id: image_name, id: image_name,

View File

@ -103,8 +103,7 @@ const nodesSlice = createSlice({
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_url, thumbnail_url } = action.payload;
action.payload;
state.nodes.forEach((node) => { state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => { forEach(node.data.inputs, (input) => {

View File

@ -68,17 +68,15 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') { if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image // We've already processed the image in the app, so we can just use the processed image
const { image_name, image_origin } = processedControlImage; const { image_name } = processedControlImage;
controlNetNode.image = { controlNetNode.image = {
image_name, image_name,
image_origin,
}; };
} else if (controlImage) { } else if (controlImage) {
// The control image is preprocessed // The control image is preprocessed
const { image_name, image_origin } = controlImage; const { image_name } = controlImage;
controlNetNode.image = { controlNetNode.image = {
image_name, image_name,
image_origin,
}; };
} else { } else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly // Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@ -354,7 +354,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
type: 'img_resize', type: 'img_resize',
image: { image: {
image_name: initialImage.image_name, image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
}, },
is_intermediate: true, is_intermediate: true,
height, height,
@ -392,7 +391,6 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name, image_name: initialImage.image_name,
image_origin: initialImage.image_origin,
}); });
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node

View File

@ -57,8 +57,7 @@ export const buildImg2ImgNode = (
} }
imageToImageNode.image = { imageToImageNode.image = {
image_name: initialImage.name, image_name: initialImage.image_name,
image_origin: initialImage.type,
}; };
} }

View File

@ -1,11 +1,5 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es'; import { ImageDTO } from 'services/api';
import { ImageDTO, ResourceOrigin } from 'services/api';
export type ImageNameAndOrigin = {
image_name: string;
image_origin: ResourceOrigin;
};
export const initialImageSelected = createAction<ImageDTO | string | undefined>( export const initialImageSelected = createAction<ImageDTO | string | undefined>(
'generation/initialImageSelected' 'generation/initialImageSelected'

View File

@ -234,8 +234,7 @@ export const generationSlice = createSlice({
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_url, thumbnail_url } = action.payload;
action.payload;
if (state.initialImage?.image_name === image_name) { if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url; state.initialImage.image_url = image_url;

View File

@ -24,6 +24,7 @@ export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
export type { DiffusersModelInfo } from './models/DiffusersModelInfo'; export type { DiffusersModelInfo } from './models/DiffusersModelInfo';
export type { DivideInvocation } from './models/DivideInvocation'; export type { DivideInvocation } from './models/DivideInvocation';
export type { DynamicPromptInvocation } from './models/DynamicPromptInvocation';
export type { Edge } from './models/Edge'; export type { Edge } from './models/Edge';
export type { EdgeConnection } from './models/EdgeConnection'; export type { EdgeConnection } from './models/EdgeConnection';
export type { FloatCollectionOutput } from './models/FloatCollectionOutput'; export type { FloatCollectionOutput } from './models/FloatCollectionOutput';
@ -86,6 +87,7 @@ export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedRe
export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput'; export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation'; export type { RandomRangeInvocation } from './models/RandomRangeInvocation';

View File

@ -0,0 +1,31 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator
*/
export type DynamicPromptInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'dynamic_prompt';
/**
* The prompt to parse with dynamicprompts
*/
prompt: string;
/**
* The number of prompts to generate
*/
max_prompts?: number;
/**
* Whether to use the combinatorial generator
*/
combinatorial?: boolean;
};

View File

@ -28,3 +28,4 @@ export type FloatLinearRangeInvocation = {
*/ */
steps?: number; steps?: number;
}; };

View File

@ -10,6 +10,7 @@ import type { ContentShuffleImageProcessorInvocation } from './ContentShuffleIma
import type { ControlNetInvocation } from './ControlNetInvocation'; import type { ControlNetInvocation } from './ControlNetInvocation';
import type { CvInpaintInvocation } from './CvInpaintInvocation'; import type { CvInpaintInvocation } from './CvInpaintInvocation';
import type { DivideInvocation } from './DivideInvocation'; import type { DivideInvocation } from './DivideInvocation';
import type { DynamicPromptInvocation } from './DynamicPromptInvocation';
import type { Edge } from './Edge'; import type { Edge } from './Edge';
import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation'; import type { FloatLinearRangeInvocation } from './FloatLinearRangeInvocation';
import type { GraphInvocation } from './GraphInvocation'; import type { GraphInvocation } from './GraphInvocation';
@ -71,9 +72,10 @@ export type Graph = {
/** /**
* The nodes in this graph * The nodes in this graph
*/ */
nodes?: Record<string, (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation)>; nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/** /**
* The connections between nodes and their fields in this graph * The connections between nodes and their fields in this graph
*/ */
edges?: Array<Edge>; edges?: Array<Edge>;
}; };

View File

@ -16,6 +16,7 @@ import type { IterateInvocationOutput } from './IterateInvocationOutput';
import type { LatentsOutput } from './LatentsOutput'; import type { LatentsOutput } from './LatentsOutput';
import type { MaskOutput } from './MaskOutput'; import type { MaskOutput } from './MaskOutput';
import type { NoiseOutput } from './NoiseOutput'; import type { NoiseOutput } from './NoiseOutput';
import type { PromptCollectionOutput } from './PromptCollectionOutput';
import type { PromptOutput } from './PromptOutput'; import type { PromptOutput } from './PromptOutput';
/** /**
@ -45,7 +46,7 @@ export type GraphExecutionState = {
/** /**
* The results of node executions * The results of node executions
*/ */
results: Record<string, (IntCollectionOutput | FloatCollectionOutput | CompelOutput | ImageOutput | MaskOutput | ControlOutput | LatentsOutput | NoiseOutput | IntOutput | FloatOutput | PromptOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>; results: Record<string, (ImageOutput | MaskOutput | ControlOutput | PromptOutput | PromptCollectionOutput | CompelOutput | IntOutput | FloatOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | FloatCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/** /**
* Errors raised when executing nodes * Errors raised when executing nodes
*/ */
@ -59,3 +60,4 @@ export type GraphExecutionState = {
*/ */
source_prepared_mapping: Record<string, Array<string>>; source_prepared_mapping: Record<string, Array<string>>;
}; };

View File

@ -14,10 +14,6 @@ export type ImageDTO = {
* The unique name of the image. * The unique name of the image.
*/ */
image_name: string; image_name: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/** /**
* The URL of the image. * The URL of the image.
*/ */
@ -26,6 +22,10 @@ export type ImageDTO = {
* The URL of the image's thumbnail. * The URL of the image's thumbnail.
*/ */
thumbnail_url: string; thumbnail_url: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/** /**
* The category of the image. * The category of the image.
*/ */

View File

@ -2,16 +2,10 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { ResourceOrigin } from './ResourceOrigin';
/** /**
* An image field used for passing image objects between invocations * An image field used for passing image objects between invocations
*/ */
export type ImageField = { export type ImageField = {
/**
* The type of the image
*/
image_origin: ResourceOrigin;
/** /**
* The name of the image * The name of the image
*/ */

View File

@ -2,8 +2,6 @@
/* tslint:disable */ /* tslint:disable */
/* eslint-disable */ /* eslint-disable */
import type { ResourceOrigin } from './ResourceOrigin';
/** /**
* The URLs for an image and its thumbnail. * The URLs for an image and its thumbnail.
*/ */
@ -12,10 +10,6 @@ export type ImageUrlsDTO = {
* The unique name of the image. * The unique name of the image.
*/ */
image_name: string; image_name: string;
/**
* The type of the image.
*/
image_origin: ResourceOrigin;
/** /**
* The URL of the image. * The URL of the image.
*/ */

View File

@ -0,0 +1,19 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Base class for invocations that output a collection of prompts
*/
export type PromptCollectionOutput = {
type: 'prompt_collection_output';
/**
* The output prompt collection
*/
prompt_collection: Array<string>;
/**
* The size of the prompt collection
*/
count: number;
};

View File

@ -56,3 +56,4 @@ export type StepParamEasingInvocation = {
*/ */
show_easing_plot?: boolean; show_easing_plot?: boolean;
}; };

View File

@ -22,33 +22,33 @@ export class ImagesService {
* @throws ApiError * @throws ApiError
*/ */
public static listImagesWithMetadata({ public static listImagesWithMetadata({
imageOrigin, imageOrigin,
categories, categories,
isIntermediate, isIntermediate,
offset, offset,
limit = 10, limit = 10,
}: { }: {
/** /**
* The origin of images to list * The origin of images to list
*/ */
imageOrigin?: ResourceOrigin, imageOrigin?: ResourceOrigin,
/** /**
* The categories of image to include * The categories of image to include
*/ */
categories?: Array<ImageCategory>, categories?: Array<ImageCategory>,
/** /**
* Whether to list intermediate images * Whether to list intermediate images
*/ */
isIntermediate?: boolean, isIntermediate?: boolean,
/** /**
* The page offset * The page offset
*/ */
offset?: number, offset?: number,
/** /**
* The number of images per page * The number of images per page
*/ */
limit?: number, limit?: number,
}): CancelablePromise<OffsetPaginatedResults_ImageDTO_> { }): CancelablePromise<OffsetPaginatedResults_ImageDTO_> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/images/', url: '/api/v1/images/',
@ -72,25 +72,25 @@ limit?: number,
* @throws ApiError * @throws ApiError
*/ */
public static uploadImage({ public static uploadImage({
imageCategory, imageCategory,
isIntermediate, isIntermediate,
formData, formData,
sessionId, sessionId,
}: { }: {
/** /**
* The category of the image * The category of the image
*/ */
imageCategory: ImageCategory, imageCategory: ImageCategory,
/** /**
* Whether this is an intermediate image * Whether this is an intermediate image
*/ */
isIntermediate: boolean, isIntermediate: boolean,
formData: Body_upload_image, formData: Body_upload_image,
/** /**
* The session ID associated with this upload, if any * The session ID associated with this upload, if any
*/ */
sessionId?: string, sessionId?: string,
}): CancelablePromise<ImageDTO> { }): CancelablePromise<ImageDTO> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'POST', method: 'POST',
url: '/api/v1/images/', url: '/api/v1/images/',
@ -115,23 +115,17 @@ sessionId?: string,
* @throws ApiError * @throws ApiError
*/ */
public static getImageFull({ public static getImageFull({
imageOrigin, imageName,
imageName, }: {
}: { /**
/** * The name of full-resolution image file to get
* The type of full-resolution image file to get */
*/ imageName: string,
imageOrigin: ResourceOrigin, }): CancelablePromise<any> {
/**
* The name of full-resolution image file to get
*/
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}', url: '/api/v1/images/{image_name}',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
errors: { errors: {
@ -148,23 +142,17 @@ imageName: string,
* @throws ApiError * @throws ApiError
*/ */
public static deleteImage({ public static deleteImage({
imageOrigin, imageName,
imageName, }: {
}: { /**
/** * The name of the image to delete
* The origin of image to delete */
*/ imageName: string,
imageOrigin: ResourceOrigin, }): CancelablePromise<any> {
/**
* The name of the image to delete
*/
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'DELETE', method: 'DELETE',
url: '/api/v1/images/{image_origin}/{image_name}', url: '/api/v1/images/{image_name}',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
errors: { errors: {
@ -180,25 +168,19 @@ imageName: string,
* @throws ApiError * @throws ApiError
*/ */
public static updateImage({ public static updateImage({
imageOrigin, imageName,
imageName, requestBody,
requestBody, }: {
}: { /**
/** * The name of the image to update
* The origin of image to update */
*/ imageName: string,
imageOrigin: ResourceOrigin, requestBody: ImageRecordChanges,
/** }): CancelablePromise<ImageDTO> {
* The name of the image to update
*/
imageName: string,
requestBody: ImageRecordChanges,
}): CancelablePromise<ImageDTO> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'PATCH', method: 'PATCH',
url: '/api/v1/images/{image_origin}/{image_name}', url: '/api/v1/images/{image_name}',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
body: requestBody, body: requestBody,
@ -216,23 +198,17 @@ requestBody: ImageRecordChanges,
* @throws ApiError * @throws ApiError
*/ */
public static getImageMetadata({ public static getImageMetadata({
imageOrigin, imageName,
imageName, }: {
}: { /**
/** * The name of image to get
* The origin of image to get */
*/ imageName: string,
imageOrigin: ResourceOrigin, }): CancelablePromise<ImageDTO> {
/**
* The name of image to get
*/
imageName: string,
}): CancelablePromise<ImageDTO> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/metadata', url: '/api/v1/images/{image_name}/metadata',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
errors: { errors: {
@ -248,23 +224,17 @@ imageName: string,
* @throws ApiError * @throws ApiError
*/ */
public static getImageThumbnail({ public static getImageThumbnail({
imageOrigin, imageName,
imageName, }: {
}: { /**
/** * The name of thumbnail image file to get
* The origin of thumbnail image file to get */
*/ imageName: string,
imageOrigin: ResourceOrigin, }): CancelablePromise<any> {
/**
* The name of thumbnail image file to get
*/
imageName: string,
}): CancelablePromise<any> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/thumbnail', url: '/api/v1/images/{image_name}/thumbnail',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
errors: { errors: {
@ -281,23 +251,17 @@ imageName: string,
* @throws ApiError * @throws ApiError
*/ */
public static getImageUrls({ public static getImageUrls({
imageOrigin, imageName,
imageName, }: {
}: { /**
/** * The name of the image whose URL to get
* The origin of the image whose URL to get */
*/ imageName: string,
imageOrigin: ResourceOrigin, }): CancelablePromise<ImageUrlsDTO> {
/**
* The name of the image whose URL to get
*/
imageName: string,
}): CancelablePromise<ImageUrlsDTO> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/images/{image_origin}/{image_name}/urls', url: '/api/v1/images/{image_name}/urls',
path: { path: {
'image_origin': imageOrigin,
'image_name': imageName, 'image_name': imageName,
}, },
errors: { errors: {

View File

@ -9,6 +9,7 @@ import type { ContentShuffleImageProcessorInvocation } from '../models/ContentSh
import type { ControlNetInvocation } from '../models/ControlNetInvocation'; import type { ControlNetInvocation } from '../models/ControlNetInvocation';
import type { CvInpaintInvocation } from '../models/CvInpaintInvocation'; import type { CvInpaintInvocation } from '../models/CvInpaintInvocation';
import type { DivideInvocation } from '../models/DivideInvocation'; import type { DivideInvocation } from '../models/DivideInvocation';
import type { DynamicPromptInvocation } from '../models/DynamicPromptInvocation';
import type { Edge } from '../models/Edge'; import type { Edge } from '../models/Edge';
import type { FloatLinearRangeInvocation } from '../models/FloatLinearRangeInvocation'; import type { FloatLinearRangeInvocation } from '../models/FloatLinearRangeInvocation';
import type { Graph } from '../models/Graph'; import type { Graph } from '../models/Graph';
@ -78,23 +79,23 @@ export class SessionsService {
* @throws ApiError * @throws ApiError
*/ */
public static listSessions({ public static listSessions({
page, page,
perPage = 10, perPage = 10,
query = '', query = '',
}: { }: {
/** /**
* The page of results to get * The page of results to get
*/ */
page?: number, page?: number,
/** /**
* The number of results per page * The number of results per page
*/ */
perPage?: number, perPage?: number,
/** /**
* The query string to search for * The query string to search for
*/ */
query?: string, query?: string,
}): CancelablePromise<PaginatedResults_GraphExecutionState_> { }): CancelablePromise<PaginatedResults_GraphExecutionState_> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/sessions/', url: '/api/v1/sessions/',
@ -116,10 +117,10 @@ query?: string,
* @throws ApiError * @throws ApiError
*/ */
public static createSession({ public static createSession({
requestBody, requestBody,
}: { }: {
requestBody?: Graph, requestBody?: Graph,
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'POST', method: 'POST',
url: '/api/v1/sessions/', url: '/api/v1/sessions/',
@ -139,13 +140,13 @@ requestBody?: Graph,
* @throws ApiError * @throws ApiError
*/ */
public static getSession({ public static getSession({
sessionId, sessionId,
}: { }: {
/** /**
* The id of the session to get * The id of the session to get
*/ */
sessionId: string, sessionId: string,
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'GET', method: 'GET',
url: '/api/v1/sessions/{session_id}', url: '/api/v1/sessions/{session_id}',
@ -166,15 +167,15 @@ sessionId: string,
* @throws ApiError * @throws ApiError
*/ */
public static addNode({ public static addNode({
sessionId, sessionId,
requestBody, requestBody,
}: { }: {
/** /**
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation), requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<string> { }): CancelablePromise<string> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'POST', method: 'POST',
url: '/api/v1/sessions/{session_id}/nodes', url: '/api/v1/sessions/{session_id}/nodes',
@ -198,20 +199,20 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
* @throws ApiError * @throws ApiError
*/ */
public static updateNode({ public static updateNode({
sessionId, sessionId,
nodePath, nodePath,
requestBody, requestBody,
}: { }: {
/** /**
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
/** /**
* The path to the node in the graph * The path to the node in the graph
*/ */
nodePath: string, nodePath: string,
requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | CompelInvocation | LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | CvInpaintInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | RestoreFaceInvocation | UpscaleInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | ImageToImageInvocation | LatentsToLatentsInvocation | InpaintInvocation), requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'PUT', method: 'PUT',
url: '/api/v1/sessions/{session_id}/nodes/{node_path}', url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
@ -236,18 +237,18 @@ requestBody: (RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation |
* @throws ApiError * @throws ApiError
*/ */
public static deleteNode({ public static deleteNode({
sessionId, sessionId,
nodePath, nodePath,
}: { }: {
/** /**
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
/** /**
* The path to the node to delete * The path to the node to delete
*/ */
nodePath: string, nodePath: string,
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'DELETE', method: 'DELETE',
url: '/api/v1/sessions/{session_id}/nodes/{node_path}', url: '/api/v1/sessions/{session_id}/nodes/{node_path}',
@ -270,15 +271,15 @@ nodePath: string,
* @throws ApiError * @throws ApiError
*/ */
public static addEdge({ public static addEdge({
sessionId, sessionId,
requestBody, requestBody,
}: { }: {
/** /**
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
requestBody: Edge, requestBody: Edge,
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'POST', method: 'POST',
url: '/api/v1/sessions/{session_id}/edges', url: '/api/v1/sessions/{session_id}/edges',
@ -302,33 +303,33 @@ requestBody: Edge,
* @throws ApiError * @throws ApiError
*/ */
public static deleteEdge({ public static deleteEdge({
sessionId, sessionId,
fromNodeId, fromNodeId,
fromField, fromField,
toNodeId, toNodeId,
toField, toField,
}: { }: {
/** /**
* The id of the session * The id of the session
*/ */
sessionId: string, sessionId: string,
/** /**
* The id of the node the edge is coming from * The id of the node the edge is coming from
*/ */
fromNodeId: string, fromNodeId: string,
/** /**
* The field of the node the edge is coming from * The field of the node the edge is coming from
*/ */
fromField: string, fromField: string,
/** /**
* The id of the node the edge is going to * The id of the node the edge is going to
*/ */
toNodeId: string, toNodeId: string,
/** /**
* The field of the node the edge is going to * The field of the node the edge is going to
*/ */
toField: string, toField: string,
}): CancelablePromise<GraphExecutionState> { }): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'DELETE', method: 'DELETE',
url: '/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}', url: '/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}',
@ -354,18 +355,18 @@ toField: string,
* @throws ApiError * @throws ApiError
*/ */
public static invokeSession({ public static invokeSession({
sessionId, sessionId,
all = false, all = false,
}: { }: {
/** /**
* The id of the session to invoke * The id of the session to invoke
*/ */
sessionId: string, sessionId: string,
/** /**
* Whether or not to invoke all remaining invocations * Whether or not to invoke all remaining invocations
*/ */
all?: boolean, all?: boolean,
}): CancelablePromise<any> { }): CancelablePromise<any> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'PUT', method: 'PUT',
url: '/api/v1/sessions/{session_id}/invoke', url: '/api/v1/sessions/{session_id}/invoke',
@ -390,13 +391,13 @@ all?: boolean,
* @throws ApiError * @throws ApiError
*/ */
public static cancelSessionInvoke({ public static cancelSessionInvoke({
sessionId, sessionId,
}: { }: {
/** /**
* The id of the session to cancel * The id of the session to cancel
*/ */
sessionId: string, sessionId: string,
}): CancelablePromise<any> { }): CancelablePromise<any> {
return __request(OpenAPI, { return __request(OpenAPI, {
method: 'DELETE', method: 'DELETE',
url: '/api/v1/sessions/{session_id}/invoke', url: '/api/v1/sessions/{session_id}/invoke',

View File

@ -44,6 +44,7 @@ dependencies = [
"datasets", "datasets",
"diffusers[torch]~=0.17.0", "diffusers[torch]~=0.17.0",
"dnspython==2.2.1", "dnspython==2.2.1",
"dynamicprompts",
"easing-functions", "easing-functions",
"einops", "einops",
"eventlet", "eventlet",