mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): remove image_origin
from most places
- remove `image_origin` from most places where we interact with images - consolidate image file storage into a single `images/` dir Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads. It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images. The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful. So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
This commit is contained in:
parent
1e08d865c9
commit
a1773197e9
@ -70,27 +70,25 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
) -> None:
|
||||
"""Deletes an image"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||
ApiDependencies.invoker.services.images.delete(image_name)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@images_router.patch(
|
||||
"/{image_origin}/{image_name}",
|
||||
"/{image_name}",
|
||||
operation_id="update_image",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def update_image(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
@ -99,32 +97,29 @@ async def update_image(
|
||||
"""Updates an image"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.update(
|
||||
image_origin, image_name, image_changes
|
||||
)
|
||||
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/metadata",
|
||||
"/{image_name}/metadata",
|
||||
operation_id="get_image_metadata",
|
||||
response_model=ImageDTO,
|
||||
)
|
||||
async def get_image_metadata(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||
image_name: str = Path(description="The name of image to get"),
|
||||
) -> ImageDTO:
|
||||
"""Gets an image's metadata"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}",
|
||||
"/{image_name}",
|
||||
operation_id="get_image_full",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@ -136,15 +131,12 @@ async def get_image_metadata(
|
||||
},
|
||||
)
|
||||
async def get_image_full(
|
||||
image_origin: ResourceOrigin = Path(
|
||||
description="The type of full-resolution image file to get"
|
||||
),
|
||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name)
|
||||
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
@ -160,7 +152,7 @@ async def get_image_full(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/thumbnail",
|
||||
"/{image_name}/thumbnail",
|
||||
operation_id="get_image_thumbnail",
|
||||
response_class=Response,
|
||||
responses={
|
||||
@ -172,14 +164,13 @@ async def get_image_full(
|
||||
},
|
||||
)
|
||||
async def get_image_thumbnail(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_origin, image_name, thumbnail=True
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
@ -192,25 +183,21 @@ async def get_image_thumbnail(
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/{image_origin}/{image_name}/urls",
|
||||
"/{image_name}/urls",
|
||||
operation_id="get_image_urls",
|
||||
response_model=ImageUrlsDTO,
|
||||
)
|
||||
async def get_image_urls(
|
||||
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||
) -> ImageUrlsDTO:
|
||||
"""Gets an image and thumbnail URL"""
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name
|
||||
)
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_origin, image_name, thumbnail=True
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
return ImageUrlsDTO(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
|
@ -193,9 +193,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
@ -216,10 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
)
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
return ImageOutput(
|
||||
image=processed_image_field,
|
||||
# width=processed_image.width,
|
||||
|
@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
mask = context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||
|
||||
# Convert to cv image/mask
|
||||
# TODO: consider making these utility functions
|
||||
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.control_image.image_origin, self.control_image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
|
||||
)
|
||||
# fmt: on
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
if image:
|
||||
image.show()
|
||||
|
||||
# TODO: how to handle failure?
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=self.image.image_name,
|
||||
image_origin=self.image.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
width=image.width,
|
||||
height=image.height,
|
||||
)
|
||||
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_crop = Image.new(
|
||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
base_image = context.services.images.get_pil_image(
|
||||
self.base_image.image_origin, self.base_image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = (
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
context.services.images.get_pil_image(
|
||||
self.mask.image_origin, self.mask.image_name
|
||||
)
|
||||
context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_mask = image.split()[-1]
|
||||
if self.invert:
|
||||
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return MaskOutput(
|
||||
mask=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
mask=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image1 = context.services.images.get_pil_image(
|
||||
self.image1.image_origin, self.image1.image_name
|
||||
)
|
||||
image2 = context.services.images.get_pil_image(
|
||||
self.image2.image_origin, self.image2.image_name
|
||||
)
|
||||
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
||||
|
||||
multiply_image = ImageChops.multiply(image1, image2)
|
||||
|
||||
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
channel_image = image.getchannel(self.channel)
|
||||
|
||||
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
converted_image = image.convert(self.mode)
|
||||
|
||||
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
blur = (
|
||||
ImageFilter.GaussianBlur(self.radius)
|
||||
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
width = int(image.width * self.scale_factor)
|
||||
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||
image_arr = image_arr * (self.max - self.min) + self.max
|
||||
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
# fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||
image_arr = (
|
||||
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(
|
||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(image.copy())
|
||||
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -321,8 +321,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
torch_dtype=model.unet.dtype).to(model.device)
|
||||
control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
||||
control_image_field.image_name)
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@ -502,10 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
@ -601,9 +597,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
# image = context.services.images.get(
|
||||
# self.image.image_type, self.image.image_name
|
||||
# )
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
|
@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=None,
|
||||
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(
|
||||
self.image.image_origin, self.image.image_name
|
||||
)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
results = context.services.restoration.upscale_and_reconstruct(
|
||||
image_list=[[image, 0]],
|
||||
upscale=(self.level, self.strength),
|
||||
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=image_dto.image_name,
|
||||
image_origin=image_dto.image_origin,
|
||||
),
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_origin: ResourceOrigin = Field(
|
||||
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
|
@ -40,14 +40,12 @@ class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@ -62,7 +60,6 @@ class ImageFileStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
@ -71,7 +68,7 @@ class ImageFileStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
@ -93,17 +90,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_origin in ResourceOrigin:
|
||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
|
||||
Path(os.path.join(output_folder, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
image_path = self.get_path(image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
@ -117,13 +111,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
@ -133,7 +126,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
@ -142,10 +135,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_origin, basename)
|
||||
image_path = self.get_path(basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
@ -153,7 +146,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
@ -163,19 +156,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
raise ImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||
self.__output_folder,
|
||||
"thumbnails",
|
||||
thumbnail_name,
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||
path = os.path.join(self.__output_folder, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
|
@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
def get(self, image_name: str) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
|
@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
def get_path(self, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
|
||||
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
def delete(self, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_origin, image_name, True
|
||||
)
|
||||
image_url = self._services.urls.get_image_url(image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
self._services.records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
return self._services.files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
return self._services.records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
image_record = self._services.records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
return self._services.files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
self._services.files.delete(image_name)
|
||||
self._services.records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
|
@ -1,17 +1,12 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
class UrlServiceBase(ABC):
|
||||
"""Responsible for building URLs for resources."""
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
|
||||
def __init__(self, base_url: str = "api/v1"):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
return f"{self._base_url}/images/{image_basename}"
|
||||
|
Loading…
Reference in New Issue
Block a user