mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): remove image_origin
from most places
- remove `image_origin` from most places where we interact with images - consolidate image file storage into a single `images/` dir Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads. It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images. The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful. So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
This commit is contained in:
@ -70,27 +70,25 @@ async def upload_image(
|
|||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
@images_router.delete("/{image_name}", operation_id="delete_image")
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Deletes an image"""
|
"""Deletes an image"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
ApiDependencies.invoker.services.images.delete(image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Does this need any exception handling at all?
|
# TODO: Does this need any exception handling at all?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@images_router.patch(
|
@images_router.patch(
|
||||||
"/{image_origin}/{image_name}",
|
"/{image_name}",
|
||||||
operation_id="update_image",
|
operation_id="update_image",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def update_image(
|
async def update_image(
|
||||||
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
|
||||||
image_name: str = Path(description="The name of the image to update"),
|
image_name: str = Path(description="The name of the image to update"),
|
||||||
image_changes: ImageRecordChanges = Body(
|
image_changes: ImageRecordChanges = Body(
|
||||||
description="The changes to apply to the image"
|
description="The changes to apply to the image"
|
||||||
@ -99,32 +97,29 @@ async def update_image(
|
|||||||
"""Updates an image"""
|
"""Updates an image"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.update(
|
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
|
||||||
image_origin, image_name, image_changes
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_origin}/{image_name}/metadata",
|
"/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def get_image_metadata(
|
async def get_image_metadata(
|
||||||
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
|
||||||
image_name: str = Path(description="The name of image to get"),
|
image_name: str = Path(description="The name of image to get"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Gets an image's metadata"""
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
return ApiDependencies.invoker.services.images.get_dto(image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_origin}/{image_name}",
|
"/{image_name}",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -136,15 +131,12 @@ async def get_image_metadata(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_full(
|
async def get_image_full(
|
||||||
image_origin: ResourceOrigin = Path(
|
|
||||||
description="The type of full-resolution image file to get"
|
|
||||||
),
|
|
||||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets a full-resolution image file"""
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
path = ApiDependencies.invoker.services.images.get_path(image_name)
|
||||||
|
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -160,7 +152,7 @@ async def get_image_full(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_origin}/{image_name}/thumbnail",
|
"/{image_name}/thumbnail",
|
||||||
operation_id="get_image_thumbnail",
|
operation_id="get_image_thumbnail",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -172,14 +164,13 @@ async def get_image_full(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_thumbnail(
|
async def get_image_thumbnail(
|
||||||
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
|
||||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets a thumbnail image file"""
|
"""Gets a thumbnail image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
image_origin, image_name, thumbnail=True
|
image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -192,25 +183,21 @@ async def get_image_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_origin}/{image_name}/urls",
|
"/{image_name}/urls",
|
||||||
operation_id="get_image_urls",
|
operation_id="get_image_urls",
|
||||||
response_model=ImageUrlsDTO,
|
response_model=ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
async def get_image_urls(
|
async def get_image_urls(
|
||||||
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
|
||||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||||
) -> ImageUrlsDTO:
|
) -> ImageUrlsDTO:
|
||||||
"""Gets an image and thumbnail URL"""
|
"""Gets an image and thumbnail URL"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||||
image_origin, image_name
|
|
||||||
)
|
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
image_origin, image_name, thumbnail=True
|
image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
return ImageUrlsDTO(
|
return ImageUrlsDTO(
|
||||||
image_origin=image_origin,
|
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
|
@ -193,9 +193,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
raw_image = context.services.images.get_pil_image(
|
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||||
processed_image = self.run_processor(raw_image)
|
processed_image = self.run_processor(raw_image)
|
||||||
|
|
||||||
@ -216,10 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
processed_image_field = ImageField(
|
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
)
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=processed_image_field,
|
image=processed_image_field,
|
||||||
# width=processed_image.width,
|
# width=processed_image.width,
|
||||||
|
@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
mask = context.services.images.get_pil_image(self.mask.image_name)
|
||||||
)
|
|
||||||
mask = context.services.images.get_pil_image(
|
|
||||||
self.mask.image_origin, self.mask.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to cv image/mask
|
# Convert to cv image/mask
|
||||||
# TODO: consider making these utility functions
|
# TODO: consider making these utility functions
|
||||||
@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -86,9 +86,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# loading controlnet image (currently requires pre-processed image)
|
# loading controlnet image (currently requires pre-processed image)
|
||||||
control_image = (
|
control_image = (
|
||||||
None if self.control_image is None
|
None if self.control_image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||||
self.control_image.image_origin, self.control_image.image_name
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# loading controlnet model
|
# loading controlnet model
|
||||||
if (self.control_model is None or self.control_model==''):
|
if (self.control_model is None or self.control_model==''):
|
||||||
@ -128,10 +126,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -169,9 +164,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.fit:
|
if self.fit:
|
||||||
@ -209,10 +202,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -282,14 +272,12 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
else context.services.images.get_pil_image(self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
@ -325,10 +313,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=self.image.image_name),
|
||||||
image_name=self.image.image_name,
|
|
||||||
image_origin=self.image.image_origin,
|
|
||||||
),
|
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
)
|
)
|
||||||
@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
if image:
|
if image:
|
||||||
image.show()
|
image.show()
|
||||||
|
|
||||||
# TODO: how to handle failure?
|
# TODO: how to handle failure?
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=self.image.image_name),
|
||||||
image_name=self.image.image_name,
|
|
||||||
image_origin=self.image.image_origin,
|
|
||||||
),
|
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
)
|
)
|
||||||
@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
image_crop = Image.new(
|
image_crop = Image.new(
|
||||||
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||||
@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(
|
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||||
self.base_image.image_origin, self.base_image.image_name
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
)
|
|
||||||
image = context.services.images.get_pil_image(
|
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else ImageOps.invert(
|
else ImageOps.invert(
|
||||||
context.services.images.get_pil_image(
|
context.services.images.get_pil_image(self.mask.image_name)
|
||||||
self.mask.image_origin, self.mask.image_name
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||||
@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
image_mask = image.split()[-1]
|
image_mask = image.split()[-1]
|
||||||
if self.invert:
|
if self.invert:
|
||||||
@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return MaskOutput(
|
return MaskOutput(
|
||||||
mask=ImageField(
|
mask=ImageField(image_name=image_dto.image_name),
|
||||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image1 = context.services.images.get_pil_image(
|
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
||||||
self.image1.image_origin, self.image1.image_name
|
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
||||||
)
|
|
||||||
image2 = context.services.images.get_pil_image(
|
|
||||||
self.image2.image_origin, self.image2.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
multiply_image = ImageChops.multiply(image1, image2)
|
multiply_image = ImageChops.multiply(image1, image2)
|
||||||
|
|
||||||
@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
channel_image = image.getchannel(self.channel)
|
channel_image = image.getchannel(self.channel)
|
||||||
|
|
||||||
@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
converted_image = image.convert(self.mode)
|
converted_image = image.convert(self.mode)
|
||||||
|
|
||||||
@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
blur = (
|
blur = (
|
||||||
ImageFilter.GaussianBlur(self.radius)
|
ImageFilter.GaussianBlur(self.radius)
|
||||||
@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
width = int(image.width * self.scale_factor)
|
width = int(image.width * self.scale_factor)
|
||||||
@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||||
image_arr = image_arr * (self.max - self.min) + self.max
|
image_arr = image_arr * (self.max - self.min) + self.max
|
||||||
@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
image_arr = (
|
image_arr = (
|
||||||
@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||||
@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
infilled = tile_fill_missing(
|
infilled = tile_fill_missing(
|
||||||
image.copy(), seed=self.seed, tile_size=self.tile_size
|
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||||
@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infilled = infill_patchmatch(image.copy())
|
infilled = infill_patchmatch(image.copy())
|
||||||
@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -321,8 +321,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
torch_dtype=model.unet.dtype).to(model.device)
|
torch_dtype=model.unet.dtype).to(model.device)
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
|
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||||
control_image_field.image_name)
|
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@ -502,10 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
@ -601,9 +597,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# image = context.services.images.get(
|
# image = context.services.images.get(
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# )
|
# )
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: this only really needs the vae
|
# TODO: this only really needs the vae
|
||||||
model_info = choose_model(context.services.model_manager, self.model)
|
model_info = choose_model(context.services.model_manager, self.model)
|
||||||
|
@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=None,
|
upscale=None,
|
||||||
@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
self.image.image_origin, self.image.image_name
|
|
||||||
)
|
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
upscale=(self.level, self.strength),
|
upscale=(self.level, self.strength),
|
||||||
@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
image_name=image_dto.image_name,
|
|
||||||
image_origin=image_dto.image_origin,
|
|
||||||
),
|
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError):
|
|||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
image_origin: ResourceOrigin = Field(
|
|
||||||
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
|
||||||
)
|
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_origin", "image_name"]}
|
schema_extra = {"required": ["image_name"]}
|
||||||
|
|
||||||
|
|
||||||
class ColorField(BaseModel):
|
class ColorField(BaseModel):
|
||||||
|
@ -40,14 +40,12 @@ class ImageFileStorageBase(ABC):
|
|||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
def get(self, image_name: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves an image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -62,7 +60,6 @@ class ImageFileStorageBase(ABC):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
@ -71,7 +68,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -93,17 +90,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||||
for image_origin in ResourceOrigin:
|
Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True)
|
||||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
Path(os.path.join(output_folder, "thumbnails")).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
def get(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_origin, image_name)
|
image_path = self.get_path(image_name)
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
@ -117,13 +111,12 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_origin, image_name)
|
image_path = self.get_path(image_name)
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
@ -133,7 +126,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
image.save(image_path, "PNG")
|
image.save(image_path, "PNG")
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
||||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
thumbnail_image.save(thumbnail_path)
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
@ -142,10 +135,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileSaveException from e
|
raise ImageFileSaveException from e
|
||||||
|
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
image_path = self.get_path(image_origin, basename)
|
image_path = self.get_path(basename)
|
||||||
|
|
||||||
if os.path.exists(image_path):
|
if os.path.exists(image_path):
|
||||||
send2trash(image_path)
|
send2trash(image_path)
|
||||||
@ -153,7 +146,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
if os.path.exists(thumbnail_path):
|
||||||
send2trash(thumbnail_path)
|
send2trash(thumbnail_path)
|
||||||
@ -163,19 +156,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
raise ImageFileDeleteException from e
|
raise ImageFileDeleteException from e
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
# strip out any relative path shenanigans
|
# strip out any relative path shenanigans
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
|
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
thumbnail_name = get_thumbnail_name(basename)
|
thumbnail_name = get_thumbnail_name(basename)
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
self.__output_folder,
|
||||||
|
"thumbnails",
|
||||||
|
thumbnail_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
path = os.path.join(self.__output_folder, basename)
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import (
|
|||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||||
"""Offset-paginated results"""
|
"""Offset-paginated results"""
|
||||||
|
|
||||||
@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
# TODO: Implement an `update()` method
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC):
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Updates an image record."""
|
"""Updates an image record."""
|
||||||
@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
"""Deletes an image record."""
|
"""Deletes an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(
|
def get(self, image_name: str) -> Union[ImageRecord, None]:
|
||||||
self, image_origin: ResourceOrigin, image_name: str
|
|
||||||
) -> Union[ImageRecord, None]:
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
|
|
||||||
if categories is not None:
|
if categories is not None:
|
||||||
## Convert the enum values to unique list of strings
|
## Convert the enum values to unique list of strings
|
||||||
category_strings = list(
|
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||||
map(lambda c: c.value, set(categories))
|
|
||||||
)
|
|
||||||
# Create the correct length of placeholders
|
# Create the correct length of placeholders
|
||||||
placeholders = ",".join("?" * len(category_strings))
|
placeholders = ",".join("?" * len(category_strings))
|
||||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||||
@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
items=images, offset=offset, limit=limit, total=count
|
items=images, offset=offset, limit=limit, total=count
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
|
@ -57,7 +57,6 @@ class ImageServiceABC(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
"""Gets an image as a PIL image."""
|
"""Gets an image as a PIL image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
"""Gets an image DTO."""
|
"""Gets an image DTO."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
def get_path(self, image_name: str) -> str:
|
||||||
"""Gets an image's path."""
|
"""Gets an image's path."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -90,9 +89,7 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_url(
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets an image's or thumbnail's URL."""
|
"""Gets an image's or thumbnail's URL."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -109,7 +106,7 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
def delete(self, image_name: str):
|
||||||
"""Deletes an image."""
|
"""Deletes an image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -206,16 +203,13 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._services.files.save(
|
self._services.files.save(
|
||||||
image_origin=image_origin,
|
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image=image,
|
image=image,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
image_url = self._services.urls.get_image_url(image_name)
|
||||||
thumbnail_url = self._services.urls.get_image_url(
|
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||||
image_origin, image_name, True
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
@ -249,13 +243,12 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
self._services.records.update(image_name, image_origin, changes)
|
self._services.records.update(image_name, changes)
|
||||||
return self.get_dto(image_origin, image_name)
|
return self.get_dto(image_name)
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to update image record")
|
self._services.logger.error("Failed to update image record")
|
||||||
raise
|
raise
|
||||||
@ -263,9 +256,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem updating image record")
|
self._services.logger.error("Problem updating image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_origin, image_name)
|
return self._services.files.get(image_name)
|
||||||
except ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
@ -273,9 +266,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image file")
|
self._services.logger.error("Problem getting image file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_origin, image_name)
|
return self._services.records.get(image_name)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
@ -283,14 +276,14 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image record")
|
self._services.logger.error("Problem getting image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.records.get(image_origin, image_name)
|
image_record = self._services.records.get(image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record,
|
||||||
self._services.urls.get_image_url(image_origin, image_name),
|
self._services.urls.get_image_url(image_name),
|
||||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
self._services.urls.get_image_url(image_name, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@ -301,11 +294,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image DTO")
|
self._services.logger.error("Problem getting image DTO")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_path(
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
try:
|
try:
|
||||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
return self._services.files.get_path(image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
@ -317,11 +308,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem validating image path")
|
self._services.logger.error("Problem validating image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_url(
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
try:
|
try:
|
||||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
@ -347,10 +336,8 @@ class ImageService(ImageServiceABC):
|
|||||||
map(
|
map(
|
||||||
lambda r: image_record_to_dto(
|
lambda r: image_record_to_dto(
|
||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
self._services.urls.get_image_url(r.image_name),
|
||||||
self._services.urls.get_image_url(
|
self._services.urls.get_image_url(r.image_name, True),
|
||||||
r.image_origin, r.image_name, True
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
@ -366,10 +353,10 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting paginated image DTOs")
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
def delete(self, image_name: str):
|
||||||
try:
|
try:
|
||||||
self._services.files.delete(image_origin, image_name)
|
self._services.files.delete(image_name)
|
||||||
self._services.records.delete(image_origin, image_name)
|
self._services.records.delete(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
|
@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel):
|
|||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
|
||||||
"""The origin of the image."""
|
|
||||||
image_url: str = Field(description="The URL of the image.")
|
image_url: str = Field(description="The URL of the image.")
|
||||||
"""The URL of the image."""
|
"""The URL of the image."""
|
||||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
|
@ -1,17 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from invokeai.app.models.image import ResourceOrigin
|
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
|
||||||
|
|
||||||
|
|
||||||
class UrlServiceBase(ABC):
|
class UrlServiceBase(ABC):
|
||||||
"""Responsible for building URLs for resources."""
|
"""Responsible for building URLs for resources."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_image_url(
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets the URL for an image or thumbnail."""
|
"""Gets the URL for an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
def __init__(self, base_url: str = "api/v1"):
|
def __init__(self, base_url: str = "api/v1"):
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
|
|
||||||
def get_image_url(
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
|
|
||||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
return (
|
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
||||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
|
||||||
)
|
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
return f"{self._base_url}/images/{image_basename}"
|
||||||
|
Reference in New Issue
Block a user