feat(nodes): remove image_origin from most places

- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful.

So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
This commit is contained in:
psychedelicious
2023-06-14 21:40:09 +10:00
parent 1e08d865c9
commit a1773197e9
15 changed files with 124 additions and 299 deletions

View File

@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name) ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@images_router.patch( @images_router.patch(
"/{image_origin}/{image_name}", "/{image_name}",
operation_id="update_image", operation_id="update_image",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def update_image( async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"), image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body( image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image" description="The changes to apply to the image"
@ -99,32 +97,29 @@ async def update_image(
"""Updates an image""" """Updates an image"""
try: try:
return ApiDependencies.invoker.services.images.update( return ApiDependencies.invoker.services.images.update(image_name, image_changes)
image_origin, image_name, image_changes
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image") raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/metadata", "/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}", "/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -136,15 +131,12 @@ async def get_image_metadata(
}, },
) )
async def get_image_full( async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"), image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/thumbnail", "/{image_name}/thumbnail",
operation_id="get_image_thumbnail", operation_id="get_image_thumbnail",
response_class=Response, response_class=Response,
responses={ responses={
@ -172,14 +164,13 @@ async def get_image_full(
}, },
) )
async def get_image_thumbnail( async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get( @images_router.get(
"/{image_origin}/{image_name}/urls", "/{image_name}/urls",
operation_id="get_image_urls", operation_id="get_image_urls",
response_model=ImageUrlsDTO, response_model=ImageUrlsDTO,
) )
async def get_image_urls( async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO: ) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images.get_url( image_url = ApiDependencies.invoker.services.images.get_url(image_name)
image_origin, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,

View File

@ -193,9 +193,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image return image
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ? # image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image) processed_image = self.run_processor(raw_image)
@ -216,10 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
) )
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField( processed_image_field = ImageField(image_name=image_dto.image_name)
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
return ImageOutput( return ImageOutput(
image=processed_image_field, image=processed_image_field,
# width=processed_image.width, # width=processed_image.width,

View File

@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name mask = context.services.images.get_pil_image(self.mask.image_name)
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
# Convert to cv image/mask # Convert to cv image/mask
# TODO: consider making these utility functions # TODO: consider making these utility functions
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
control_image = ( control_image = (
None if self.control_image is None None if self.control_image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.control_image.image_name)
self.control_image.image_origin, self.control_image.image_name
)
) )
# loading controlnet model # loading controlnet model
if (self.control_model is None or self.control_model==''): if (self.control_model is None or self.control_model==''):
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = ( image = (
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
) )
if self.fit: if self.fit:
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
image = ( image = (
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_name)
) )
# Handle invalid model parameter # Handle invalid model parameter
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
) )
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=self.image.image_name),
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width, width=image.width,
height=image.height, height=image.height,
) )
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
if image: if image:
image.show() image.show()
# TODO: how to handle failure? # TODO: how to handle failure?
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=self.image.image_name),
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
width=image.width, width=image.width,
height=image.height, height=image.height,
) )
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_crop = Image.new( image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0) mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image( base_image = context.services.images.get_pil_image(self.base_image.image_name)
self.base_image.image_origin, self.base_image.image_name image = context.services.images.get_pil_image(self.image.image_name)
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get_pil_image( context.services.images.get_pil_image(self.mask.image_name)
self.mask.image_origin, self.mask.image_name
)
) )
) )
# TODO: probably shouldn't invert mask here... should user be required to do it? # TODO: probably shouldn't invert mask here... should user be required to do it?
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_mask = image.split()[-1] image_mask = image.split()[-1]
if self.invert: if self.invert:
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
) )
return MaskOutput( return MaskOutput(
mask=ImageField( mask=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image( image1 = context.services.images.get_pil_image(self.image1.image_name)
self.image1.image_origin, self.image1.image_name image2 = context.services.images.get_pil_image(self.image2.image_name)
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2) multiply_image = ImageChops.multiply(image1, image2)
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
channel_image = image.getchannel(self.channel) channel_image = image.getchannel(self.channel)
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
converted_image = image.convert(self.mode) converted_image = image.convert(self.mode)
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
blur = ( blur = (
ImageFilter.GaussianBlur(self.radius) ImageFilter.GaussianBlur(self.radius)
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor) width = int(image.width * self.scale_factor)
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max image_arr = image_arr * (self.max - self.min) + self.max
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = ( image_arr = (
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
infilled = tile_fill_missing( infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size image.copy(), seed=self.seed, tile_size=self.tile_size
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy()) infilled = infill_patchmatch(image.copy())
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -321,8 +321,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin, input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt? # and add in batch_size, num_images_per_prompt?
@ -502,10 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )
@ -601,9 +597,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get( # image = context.services.images.get(
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# ) # )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
# TODO: this only really needs the vae # TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model) model_info = choose_model(context.services.model_manager, self.model)

View File

@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
upscale=None, upscale=None,
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(self.image.image_name)
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
upscale=(self.level, self.strength), upscale=(self.level, self.strength),
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
) )
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(image_name=image_dto.image_name),
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
) )

View File

@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config: class Config:
schema_extra = {"required": ["image_origin", "image_name"]} schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel): class ColorField(BaseModel):

View File

@ -40,14 +40,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
@ -62,7 +60,6 @@ class ImageFileStorageBase(ABC):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
@ -71,7 +68,7 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@ -93,17 +90,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath? # TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin: Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
Path(os.path.join(output_folder, image_origin)).mkdir( Path(os.path.join(output_folder, "thumbnails")).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get(self, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_origin, image_name) image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -117,13 +111,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_origin, image_name) image_path = self.get_path(image_name)
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +126,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG") image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
@ -142,10 +135,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename) image_path = self.get_path(basename)
if os.path.exists(image_path): if os.path.exists(image_path):
send2trash(image_path) send2trash(image_path)
@ -153,7 +146,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path] del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True) thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path): if os.path.exists(thumbnail_path):
send2trash(thumbnail_path) send2trash(thumbnail_path)
@ -163,19 +156,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans # strip out any relative path shenanigans
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(basename)
path = os.path.join( path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name self.__output_folder,
"thumbnails",
thumbnail_name,
) )
else: else:
path = os.path.join(self.__output_folder, image_origin, basename) path = os.path.join(self.__output_folder, basename)
abspath = os.path.abspath(path) abspath = os.path.abspath(path)

View File

@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]): class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results""" """Offset-paginated results"""
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method # TODO: Implement an `update()` method
@abstractmethod @abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get(self, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update( def update(
self, self,
image_name: str, image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
"""Updates an image record.""" """Updates an image record."""
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused. # TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage. # Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get( def get(self, image_name: str) -> Union[ImageRecord, None]:
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update( def update(
self, self,
image_name: str, image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
try: try:
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None: if categories is not None:
## Convert the enum values to unique list of strings ## Convert the enum values to unique list of strings
category_strings = list( category_strings = list(map(lambda c: c.value, set(categories)))
map(lambda c: c.value, set(categories))
)
# Create the correct length of placeholders # Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings)) placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n" query_conditions += f"AND image_category IN ( {placeholders} )\n"
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count items=images, offset=offset, limit=limit, total=count
) )
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: def delete(self, image_name: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def update( def update(
self, self,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
pass pass
@abstractmethod @abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod @abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: def get_path(self, image_name: str) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_url( def get_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL.""" """Gets an image's or thumbnail's URL."""
pass pass
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str): def delete(self, image_name: str):
"""Deletes an image.""" """Deletes an image."""
pass pass
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
) )
self._services.files.save( self._services.files.save(
image_origin=image_origin,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_origin, image_name) image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(image_name, True)
image_origin, image_name, True
)
return ImageDTO( return ImageDTO(
# Non-nullable fields # Non-nullable fields
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
def update( def update(
self, self,
image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, image_origin, changes) self._services.records.update(image_name, changes)
return self.get_dto(image_origin, image_name) return self.get_dto(image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
raise raise
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record") self._services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_origin, image_name) return self._services.files.get(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file") self._services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_origin, image_name) return self._services.records.get(image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_origin, image_name) image_record = self._services.records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_origin, image_name), self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_origin, image_name, True), self._services.urls.get_image_url(image_name, True),
) )
return image_dto return image_dto
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO") self._services.logger.error("Problem getting image DTO")
raise e raise e
def get_path( def get_path(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try: try:
return self._services.files.get_path(image_origin, image_name, thumbnail) return self._services.files.get_path(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem validating image path") self._services.logger.error("Problem validating image path")
raise e raise e
def get_url( def get_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try: try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail) return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(r.image_origin, r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url( self._services.urls.get_image_url(r.image_name, True),
r.image_origin, r.image_name, True
),
), ),
results.items, results.items,
) )
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_origin: ResourceOrigin, image_name: str): def delete(self, image_name: str):
try: try:
self._services.files.delete(image_origin, image_name) self._services.files.delete(image_name)
self._services.records.delete(image_origin, image_name) self._services.records.delete(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise

View File

@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
"""The URL of the image.""" """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.") thumbnail_url: str = Field(description="The URL of the image's thumbnail.")

View File

@ -1,17 +1,12 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC): class UrlServiceBase(ABC):
"""Responsible for building URLs for resources.""" """Responsible for building URLs for resources."""
@abstractmethod @abstractmethod
def get_image_url( def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"): def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url self._base_url = base_url
def get_image_url( def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail: if thumbnail:
return ( return f"{self._base_url}/images/{image_basename}/thumbnail"
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_origin.value}/{image_basename}" return f"{self._base_url}/images/{image_basename}"