feat(nodes): remove image_origin from most places

- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful.

So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
This commit is contained in:
psychedelicious 2023-06-14 21:40:09 +10:00
parent 1e08d865c9
commit a1773197e9
15 changed files with 124 additions and 299 deletions

View File

@ -70,27 +70,25 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
ApiDependencies.invoker.services.images.delete(image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@images_router.patch(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
@ -99,32 +97,29 @@ async def update_image(
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_origin, image_name, image_changes
)
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_origin}/{image_name}/metadata",
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
return ApiDependencies.invoker.services.images.get_dto(image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_origin}/{image_name}",
"/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@ -136,15 +131,12 @@ async def get_image_metadata(
},
)
async def get_image_full(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
path = ApiDependencies.invoker.services.images.get_path(image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -160,7 +152,7 @@ async def get_image_full(
@images_router.get(
"/{image_origin}/{image_name}/thumbnail",
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@ -172,14 +164,13 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -192,25 +183,21 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_origin}/{image_name}/urls",
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name
)
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_origin, image_name, thumbnail=True
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,

View File

@ -193,9 +193,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
return image
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
raw_image = context.services.images.get_pil_image(self.image.image_name)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -216,10 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
)
processed_image_field = ImageField(image_name=image_dto.image_name)
return ImageOutput(
image=processed_image_field,
# width=processed_image.width,

View File

@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
mask = context.services.images.get_pil_image(self.mask.image_name)
# Convert to cv image/mask
# TODO: consider making these utility functions
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
else context.services.images.get_pil_image(self.control_image.image_name)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
if self.fit:
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_name)
)
# Handle invalid model parameter
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if image:
image.show()
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_origin=self.image.image_origin,
),
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_origin, self.mask.image_name
)
context.services.images.get_pil_image(self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_mask = image.split()[-1]
if self.invert:
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
)
return MaskOutput(
mask=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_origin, self.image2.image_name
)
image1 = context.services.images.get_pil_image(self.image1.image_name)
image2 = context.services.images.get_pil_image(self.image2.image_name)
multiply_image = ImageChops.multiply(image1, image2)
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
channel_image = image.getchannel(self.channel)
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
converted_image = image.convert(self.mode)
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius)
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
image_arr = image_arr * (self.max - self.min) + self.max
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = (
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(image.copy())
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -321,8 +321,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@ -502,10 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
@ -601,9 +597,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)

View File

@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=None,
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
image = context.services.images.get_pil_image(self.image.image_name)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
upscale=(self.level, self.strength),
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_origin: ResourceOrigin = Field(
default=ResourceOrigin.INTERNAL, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_origin", "image_name"]}
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):

View File

@ -40,14 +40,12 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@ -62,7 +60,6 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@ -71,7 +68,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@ -93,17 +90,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
Path(os.path.join(output_folder, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_origin, image_name)
image_path = self.get_path(image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@ -117,13 +111,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_origin, image_name)
image_path = self.get_path(image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +126,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@ -142,10 +135,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_origin, basename)
image_path = self.get_path(basename)
if os.path.exists(image_path):
send2trash(image_path)
@ -153,7 +146,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
thumbnail_path = self.get_path(thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
@ -163,19 +156,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_origin, "thumbnails", thumbnail_name
self.__output_folder,
"thumbnails",
thumbnail_name,
)
else:
path = os.path.join(self.__output_folder, image_origin, basename)
path = os.path.join(self.__output_folder, basename)
abspath = os.path.abspath(path)

View File

@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
"""Deletes an image record."""
pass
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
def get(self, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update(
self,
image_name: str,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
if categories is not None:
## Convert the enum values to unique list of strings
category_strings = list(
map(lambda c: c.value, set(categories))
)
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
def delete(self, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(

View File

@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
@abstractmethod
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
def get_path(self, image_name: str) -> str:
"""Gets an image's path."""
pass
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
"""Deletes an image."""
pass
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
)
self._services.files.save(
image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_origin, image_name, True
)
image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
return ImageDTO(
# Non-nullable fields
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
def update(
self,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
self._services.records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_origin, image_name)
return self._services.files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_origin, image_name)
return self._services.records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_origin, image_name)
image_record = self._services.records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
)
return image_dto
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_origin, image_name, thumbnail)
return self._services.files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
return self._services.urls.get_image_url(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_origin, r.image_name, True
),
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
),
results.items,
)
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_origin: ResourceOrigin, image_name: str):
def delete(self, image_name: str):
try:
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
self._services.files.delete(image_name)
self._services.records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")

View File

@ -1,17 +1,12 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
return f"{self._base_url}/images/{image_basename}"