mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): refactor image types
- Remove `ImageType` entirely, it is confusing - Create `ResourceOrigin`, may be `internal` or `external` - Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on - Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change. The new setup should account for our types of images, including the combinations we couldn't really handle until now: - Canvas init and masks - Canvas when saved-to-gallery or merged
This commit is contained in:
parent
fd47e70c92
commit
160267c71a
@ -1,39 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResponseMetadata(BaseModel):
|
|
||||||
"""An image's metadata. Used only in HTTP responses."""
|
|
||||||
|
|
||||||
created: int = Field(description="The creation timestamp of the image")
|
|
||||||
width: int = Field(description="The width of the image in pixels")
|
|
||||||
height: int = Field(description="The height of the image in pixels")
|
|
||||||
# invokeai: Optional[InvokeAIMetadata] = Field(
|
|
||||||
# description="The image's InvokeAI-specific metadata"
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(BaseModel):
|
|
||||||
"""The response type for images"""
|
|
||||||
|
|
||||||
image_type: ImageType = Field(description="The type of the image")
|
|
||||||
image_name: str = Field(description="The name of the image")
|
|
||||||
image_url: str = Field(description="The url of the image")
|
|
||||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
|
||||||
metadata: ImageResponseMetadata = Field(description="The image's metadata")
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressImage(BaseModel):
|
|
||||||
"""The progress image sent intermittently during processing"""
|
|
||||||
|
|
||||||
width: int = Field(description="The effective width of the image in pixels")
|
|
||||||
height: int = Field(description="The effective height of the image in pixels")
|
|
||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
|
||||||
|
|
||||||
|
|
||||||
class SavedImage(BaseModel):
|
|
||||||
image_name: str = Field(description="The name of the saved image")
|
|
||||||
thumbnail_name: str = Field(description="The name of the saved thumbnail")
|
|
||||||
created: int = Field(description="The created timestamp of the saved image")
|
|
@ -6,7 +6,7 @@ from fastapi.responses import FileResponse
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
ImageDTO,
|
ImageDTO,
|
||||||
@ -36,9 +36,6 @@ async def upload_image(
|
|||||||
response: Response,
|
response: Response,
|
||||||
image_category: ImageCategory = Query(description="The category of the image"),
|
image_category: ImageCategory = Query(description="The category of the image"),
|
||||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||||
show_in_gallery: bool = Query(
|
|
||||||
description="Whether this image should be shown in the gallery"
|
|
||||||
),
|
|
||||||
session_id: Optional[str] = Query(
|
session_id: Optional[str] = Query(
|
||||||
default=None, description="The session ID associated with this upload, if any"
|
default=None, description="The session ID associated with this upload, if any"
|
||||||
),
|
),
|
||||||
@ -58,11 +55,10 @@ async def upload_image(
|
|||||||
try:
|
try:
|
||||||
image_dto = ApiDependencies.invoker.services.images.create(
|
image_dto = ApiDependencies.invoker.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
image_type=ImageType.UPLOAD,
|
image_origin=ResourceOrigin.EXTERNAL,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
show_in_gallery=show_in_gallery,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response.status_code = 201
|
response.status_code = 201
|
||||||
@ -73,27 +69,27 @@ async def upload_image(
|
|||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
image_type: ImageType = Path(description="The type of image to delete"),
|
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Deletes an image"""
|
"""Deletes an image"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.images.delete(image_type, image_name)
|
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Does this need any exception handling at all?
|
# TODO: Does this need any exception handling at all?
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@images_router.patch(
|
@images_router.patch(
|
||||||
"/{image_type}/{image_name}",
|
"/{image_origin}/{image_name}",
|
||||||
operation_id="update_image",
|
operation_id="update_image",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def update_image(
|
async def update_image(
|
||||||
image_type: ImageType = Path(description="The type of image to update"),
|
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
|
||||||
image_name: str = Path(description="The name of the image to update"),
|
image_name: str = Path(description="The name of the image to update"),
|
||||||
image_changes: ImageRecordChanges = Body(
|
image_changes: ImageRecordChanges = Body(
|
||||||
description="The changes to apply to the image"
|
description="The changes to apply to the image"
|
||||||
@ -103,31 +99,31 @@ async def update_image(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.update(
|
return ApiDependencies.invoker.services.images.update(
|
||||||
image_type, image_name, image_changes
|
image_origin, image_name, image_changes
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail="Failed to update image")
|
raise HTTPException(status_code=400, detail="Failed to update image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/metadata",
|
"/{image_origin}/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def get_image_metadata(
|
async def get_image_metadata(
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
|
||||||
image_name: str = Path(description="The name of image to get"),
|
image_name: str = Path(description="The name of image to get"),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Gets an image's metadata"""
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
|
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}",
|
"/{image_origin}/{image_name}",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -139,7 +135,7 @@ async def get_image_metadata(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_full(
|
async def get_image_full(
|
||||||
image_type: ImageType = Path(
|
image_origin: ResourceOrigin = Path(
|
||||||
description="The type of full-resolution image file to get"
|
description="The type of full-resolution image file to get"
|
||||||
),
|
),
|
||||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||||
@ -147,7 +143,7 @@ async def get_image_full(
|
|||||||
"""Gets a full-resolution image file"""
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
|
||||||
|
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -163,7 +159,7 @@ async def get_image_full(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/thumbnail",
|
"/{image_origin}/{image_name}/thumbnail",
|
||||||
operation_id="get_image_thumbnail",
|
operation_id="get_image_thumbnail",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -175,14 +171,14 @@ async def get_image_full(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_image_thumbnail(
|
async def get_image_thumbnail(
|
||||||
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
|
||||||
image_name: str = Path(description="The name of thumbnail image file to get"),
|
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||||
) -> FileResponse:
|
) -> FileResponse:
|
||||||
"""Gets a thumbnail image file"""
|
"""Gets a thumbnail image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
image_type, image_name, thumbnail=True
|
image_origin, image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
@ -195,25 +191,25 @@ async def get_image_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_type}/{image_name}/urls",
|
"/{image_origin}/{image_name}/urls",
|
||||||
operation_id="get_image_urls",
|
operation_id="get_image_urls",
|
||||||
response_model=ImageUrlsDTO,
|
response_model=ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
async def get_image_urls(
|
async def get_image_urls(
|
||||||
image_type: ImageType = Path(description="The type of the image whose URL to get"),
|
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
|
||||||
image_name: str = Path(description="The name of the image whose URL to get"),
|
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||||
) -> ImageUrlsDTO:
|
) -> ImageUrlsDTO:
|
||||||
"""Gets an image and thumbnail URL"""
|
"""Gets an image and thumbnail URL"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
image_url = ApiDependencies.invoker.services.images.get_url(
|
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
image_type, image_name
|
image_origin, image_name
|
||||||
)
|
)
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
image_type, image_name, thumbnail=True
|
image_origin, image_name, thumbnail=True
|
||||||
)
|
)
|
||||||
return ImageUrlsDTO(
|
return ImageUrlsDTO(
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
@ -228,30 +224,33 @@ async def get_image_urls(
|
|||||||
response_model=PaginatedResults[ImageDTO],
|
response_model=PaginatedResults[ImageDTO],
|
||||||
)
|
)
|
||||||
async def list_images_with_metadata(
|
async def list_images_with_metadata(
|
||||||
image_type: Optional[ImageType] = Query(
|
image_origin: Optional[ResourceOrigin] = Query(
|
||||||
default=None, description="The type of images to list"
|
default=None, description="The origin of images to list"
|
||||||
),
|
),
|
||||||
image_category: Optional[ImageCategory] = Query(
|
include_categories: Optional[list[ImageCategory]] = Query(
|
||||||
default=None, description="The kind of images to list"
|
default=None, description="The categories of image to include"
|
||||||
|
),
|
||||||
|
exclude_categories: Optional[list[ImageCategory]] = Query(
|
||||||
|
default=None, description="The categories of image to exclude"
|
||||||
),
|
),
|
||||||
is_intermediate: Optional[bool] = Query(
|
is_intermediate: Optional[bool] = Query(
|
||||||
default=None, description="Whether to list intermediate images"
|
default=None, description="Whether to list intermediate images"
|
||||||
),
|
),
|
||||||
show_in_gallery: Optional[bool] = Query(
|
|
||||||
default=None, description="Whether to list images that show in the gallery"
|
|
||||||
),
|
|
||||||
page: int = Query(default=0, description="The page of images to get"),
|
page: int = Query(default=0, description="The page of images to get"),
|
||||||
per_page: int = Query(default=10, description="The number of images per page"),
|
per_page: int = Query(default=10, description="The number of images per page"),
|
||||||
) -> PaginatedResults[ImageDTO]:
|
) -> PaginatedResults[ImageDTO]:
|
||||||
"""Gets a list of images"""
|
"""Gets a list of images"""
|
||||||
|
|
||||||
|
if include_categories is not None and exclude_categories is not None:
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.")
|
||||||
|
|
||||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||||
page,
|
page,
|
||||||
per_page,
|
per_page,
|
||||||
image_type,
|
image_origin,
|
||||||
image_category,
|
include_categories,
|
||||||
|
exclude_categories,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
show_in_gallery,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dtos
|
return image_dtos
|
||||||
|
@ -7,7 +7,7 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = context.services.images.get_pil_image(
|
mask = context.services.images.get_pil_image(
|
||||||
self.mask.image_type, self.mask.image_name
|
self.mask.image_origin, self.mask.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to cv image/mask
|
# Convert to cv image/mask
|
||||||
@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_inpainted,
|
image=image_inpainted,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -10,9 +10,9 @@ import torch
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generate_output.image,
|
image=generate_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generator_output.image,
|
image=generator_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get_pil_image(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
|
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=generator_output.image,
|
image=generator_output.image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -7,7 +7,7 @@ import numpy
|
|||||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ImageType
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_name=self.image.image_name,
|
||||||
image_type=self.image.image_type,
|
image_origin=self.image.image_origin,
|
||||||
),
|
),
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
if image:
|
if image:
|
||||||
image.show()
|
image.show()
|
||||||
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_name=self.image.image_name,
|
||||||
image_type=self.image.image_type,
|
image_origin=self.image.image_origin,
|
||||||
),
|
),
|
||||||
width=image.width,
|
width=image.width,
|
||||||
height=image.height,
|
height=image.height,
|
||||||
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_crop = Image.new(
|
image_crop = Image.new(
|
||||||
@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_crop,
|
image=image_crop,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(
|
base_image = context.services.images.get_pil_image(
|
||||||
self.base_image.image_type, self.base_image.image_name
|
self.base_image.image_origin, self.base_image.image_name
|
||||||
)
|
)
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else ImageOps.invert(
|
else ImageOps.invert(
|
||||||
context.services.images.get_pil_image(
|
context.services.images.get_pil_image(
|
||||||
self.mask.image_type, self.mask.image_name
|
self.mask.image_origin, self.mask.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=new_image,
|
image=new_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_mask = image.split()[-1]
|
image_mask = image.split()[-1]
|
||||||
@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image_mask,
|
image=image_mask,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.MASK,
|
image_category=ImageCategory.MASK,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
return MaskOutput(
|
return MaskOutput(
|
||||||
mask=ImageField(
|
mask=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image1 = context.services.images.get_pil_image(
|
image1 = context.services.images.get_pil_image(
|
||||||
self.image1.image_type, self.image1.image_name
|
self.image1.image_origin, self.image1.image_name
|
||||||
)
|
)
|
||||||
image2 = context.services.images.get_pil_image(
|
image2 = context.services.images.get_pil_image(
|
||||||
self.image2.image_type, self.image2.image_name
|
self.image2.image_origin, self.image2.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
multiply_image = ImageChops.multiply(image1, image2)
|
multiply_image = ImageChops.multiply(image1, image2)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=multiply_image,
|
image=multiply_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_image = image.getchannel(self.channel)
|
channel_image = image.getchannel(self.channel)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=channel_image,
|
image=channel_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
converted_image = image.convert(self.mode)
|
converted_image = image.convert(self.mode)
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=converted_image,
|
image=converted_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_type=image_dto.image_type, image_name=image_dto.image_name
|
image_origin=image_dto.image_origin, image_name=image_dto.image_name
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
blur = (
|
blur = (
|
||||||
@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=blur_image,
|
image=blur_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -409,7 +409,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -430,7 +430,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
||||||
@ -440,7 +440,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=lerp_image,
|
image=lerp_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -450,7 +450,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -471,7 +471,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
@ -486,7 +486,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=ilerp_image,
|
image=ilerp_image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -496,7 +496,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
|
|||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
|
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
infilled = tile_fill_missing(
|
infilled = tile_fill_missing(
|
||||||
@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=infilled,
|
image=infilled,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
|||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_file_storage import ImageType
|
from ..services.image_file_storage import ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
# and gnenerate unique image_name
|
# and gnenerate unique image_name
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image,
|
image=image,
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# )
|
# )
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: this only really needs the vae
|
# TODO: this only really needs the vae
|
||||||
|
@ -2,7 +2,7 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -4,7 +4,7 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_origin, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
image_list=[[image, 0]],
|
image_list=[[image, 0]],
|
||||||
@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
image_type=ImageType.RESULT,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(
|
image=ImageField(
|
||||||
image_name=image_dto.image_name,
|
image_name=image_dto.image_name,
|
||||||
image_type=image_dto.image_type,
|
image_origin=image_dto.image_origin,
|
||||||
),
|
),
|
||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
|
@ -5,30 +5,52 @@ from pydantic import BaseModel, Field
|
|||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
class ImageType(str, Enum, metaclass=MetaEnum):
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
"""The type of an image."""
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
RESULT = "results"
|
- INTERNAL: The resource was created by the application.
|
||||||
UPLOAD = "uploads"
|
- EXTERNAL: The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
INTERNAL = "internal"
|
||||||
|
"""The resource was created by the application."""
|
||||||
|
EXTERNAL = "external"
|
||||||
|
"""The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageTypeException(ValueError):
|
class InvalidOriginException(ValueError):
|
||||||
"""Raised when a provided value is not a valid ImageType.
|
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
Subclasses `ValueError`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, message="Invalid image type."):
|
def __init__(self, message="Invalid resource origin."):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
"""The category of an image.
|
||||||
|
|
||||||
|
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||||
|
- MASK: The image is a mask image.
|
||||||
|
- CONTROL: The image is a ControlNet control image.
|
||||||
|
- USER: The image is a user-provide image.
|
||||||
|
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
GENERAL = "general"
|
GENERAL = "general"
|
||||||
CONTROL = "control"
|
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||||
MASK = "mask"
|
MASK = "mask"
|
||||||
|
"""MASK: The image is a mask image."""
|
||||||
|
CONTROL = "control"
|
||||||
|
"""CONTROL: The image is a ControlNet control image."""
|
||||||
|
USER = "user"
|
||||||
|
"""USER: The image is a user-provide image."""
|
||||||
OTHER = "other"
|
OTHER = "other"
|
||||||
|
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageCategoryException(ValueError):
|
class InvalidImageCategoryException(ValueError):
|
||||||
@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
|
|||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
image_type: ImageType = Field(
|
image_origin: ResourceOrigin = Field(
|
||||||
default=ImageType.RESULT, description="The type of the image"
|
default=ResourceOrigin.INTERNAL, description="The type of the image"
|
||||||
)
|
)
|
||||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["image_type", "image_name"]}
|
schema_extra = {"required": ["image_origin", "image_name"]}
|
||||||
|
|
||||||
|
|
||||||
class ColorField(BaseModel):
|
class ColorField(BaseModel):
|
||||||
@ -61,3 +83,11 @@ class ColorField(BaseModel):
|
|||||||
|
|
||||||
def tuple(self) -> Tuple[int, int, int, int]:
|
def tuple(self) -> Tuple[int, int, int, int]:
|
||||||
return (self.r, self.g, self.b, self.a)
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressImage(BaseModel):
|
||||||
|
"""The progress image sent intermittently during processing"""
|
||||||
|
|
||||||
|
width: int = Field(description="The effective width of the image in pixels")
|
||||||
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
from invokeai.app.api.models.images import ProgressImage
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
|
|||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves an image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||||
for image_type in ImageType:
|
for image_origin in ResourceOrigin:
|
||||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||||
parents=True, exist_ok=True
|
parents=True, exist_ok=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_origin, image_name)
|
||||||
cache_item = self.__get_cache(image_path)
|
cache_item = self.__get_cache(image_path)
|
||||||
if cache_item:
|
if cache_item:
|
||||||
return cache_item
|
return cache_item
|
||||||
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
metadata: Optional[ImageMetadata] = None,
|
metadata: Optional[ImageMetadata] = None,
|
||||||
thumbnail_size: int = 256,
|
thumbnail_size: int = 256,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_origin, image_name)
|
||||||
|
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
image.save(image_path, "PNG")
|
image.save(image_path, "PNG")
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
thumbnail_image.save(thumbnail_path)
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImageFileSaveException from e
|
raise ImageFileSaveException from e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
image_path = self.get_path(image_type, basename)
|
image_path = self.get_path(image_origin, basename)
|
||||||
|
|
||||||
if os.path.exists(image_path):
|
if os.path.exists(image_path):
|
||||||
send2trash(image_path)
|
send2trash(image_path)
|
||||||
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
if os.path.exists(thumbnail_path):
|
||||||
send2trash(thumbnail_path)
|
send2trash(thumbnail_path)
|
||||||
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
# strip out any relative path shenanigans
|
# strip out any relative path shenanigans
|
||||||
basename = os.path.basename(image_name)
|
basename = os.path.basename(image_name)
|
||||||
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if thumbnail:
|
if thumbnail:
|
||||||
thumbnail_name = get_thumbnail_name(basename)
|
thumbnail_name = get_thumbnail_name(basename)
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
path = os.path.join(self.__output_folder, image_type, basename)
|
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from typing import Optional, Union
|
|||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.models.image_record import (
|
from invokeai.app.services.models.image_record import (
|
||||||
ImageRecord,
|
ImageRecord,
|
||||||
@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
# TODO: Implement an `update()` method
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Updates an image record."""
|
"""Updates an image record."""
|
||||||
@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC):
|
|||||||
self,
|
self,
|
||||||
page: int = 0,
|
page: int = 0,
|
||||||
per_page: int = 10,
|
per_page: int = 10,
|
||||||
image_type: Optional[ImageType] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
image_category: Optional[ImageCategory] = None,
|
include_categories: Optional[list[ImageCategory]] = None,
|
||||||
|
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
show_in_gallery: Optional[bool] = None,
|
|
||||||
) -> PaginatedResults[ImageRecord]:
|
) -> PaginatedResults[ImageRecord]:
|
||||||
"""Gets a page of image records."""
|
"""Gets a page of image records."""
|
||||||
pass
|
pass
|
||||||
@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
"""Deletes an image record."""
|
"""Deletes an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
width: int,
|
width: int,
|
||||||
height: int,
|
height: int,
|
||||||
@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC):
|
|||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
show_in_gallery: bool = True,
|
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
CREATE TABLE IF NOT EXISTS images (
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
image_name TEXT NOT NULL PRIMARY KEY,
|
image_name TEXT NOT NULL PRIMARY KEY,
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
image_type TEXT NOT NULL,
|
image_origin TEXT NOT NULL,
|
||||||
-- This is an enum in python, unrestricted string here for flexibility
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
image_category TEXT NOT NULL,
|
image_category TEXT NOT NULL,
|
||||||
width INTEGER NOT NULL,
|
width INTEGER NOT NULL,
|
||||||
@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
session_id TEXT,
|
session_id TEXT,
|
||||||
node_id TEXT,
|
node_id TEXT,
|
||||||
metadata TEXT,
|
metadata TEXT,
|
||||||
show_in_gallery BOOLEAN DEFAULT TRUE,
|
|
||||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
def get(
|
||||||
|
self, image_origin: ResourceOrigin, image_name: str
|
||||||
|
) -> Union[ImageRecord, None]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self,
|
self,
|
||||||
page: int = 0,
|
page: int = 0,
|
||||||
per_page: int = 10,
|
per_page: int = 10,
|
||||||
image_type: Optional[ImageType] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
image_category: Optional[ImageCategory] = None,
|
include_categories: Optional[list[ImageCategory]] = None,
|
||||||
|
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
show_in_gallery: Optional[bool] = None,
|
|
||||||
) -> PaginatedResults[ImageRecord]:
|
) -> PaginatedResults[ImageRecord]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
# Manually build two queries - one for the count, one for the records
|
# Manually build two queries - one for the count, one for the records
|
||||||
|
|
||||||
count_query = """--sql
|
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||||
SELECT COUNT(*) FROM images WHERE 1=1
|
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||||
"""
|
|
||||||
|
|
||||||
images_query = """--sql
|
|
||||||
SELECT * FROM images WHERE 1=1
|
|
||||||
"""
|
|
||||||
|
|
||||||
query_conditions = ""
|
query_conditions = ""
|
||||||
query_params = []
|
query_params = []
|
||||||
|
|
||||||
if image_type is not None:
|
if image_origin is not None:
|
||||||
query_conditions += """--sql
|
query_conditions += f"""AND image_origin = ?\n"""
|
||||||
AND image_type = ?
|
query_params.append(image_origin.value)
|
||||||
"""
|
|
||||||
query_params.append(image_type.value)
|
|
||||||
|
|
||||||
if image_category is not None:
|
if include_categories is not None:
|
||||||
query_conditions += """--sql
|
## Convert the enum values to unique list of strings
|
||||||
AND image_category = ?
|
include_category_strings = list(
|
||||||
"""
|
map(lambda c: c.value, set(include_categories))
|
||||||
query_params.append(image_category.value)
|
)
|
||||||
|
# Create the correct length of placeholders
|
||||||
|
placeholders = ",".join("?" * len(include_category_strings))
|
||||||
|
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||||
|
|
||||||
|
# Unpack the included categories into the query params
|
||||||
|
query_params.append(*include_category_strings)
|
||||||
|
|
||||||
|
if exclude_categories is not None:
|
||||||
|
## Convert the enum values to unique list of strings
|
||||||
|
exclude_category_strings = list(
|
||||||
|
map(lambda c: c.value, set(exclude_categories))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the correct length of placeholders
|
||||||
|
placeholders = ",".join("?" * len(exclude_category_strings))
|
||||||
|
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
|
||||||
|
|
||||||
|
# Unpack the included categories into the query params
|
||||||
|
query_params.append(*exclude_category_strings)
|
||||||
|
|
||||||
if is_intermediate is not None:
|
if is_intermediate is not None:
|
||||||
query_conditions += """--sql
|
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||||
AND is_intermediate = ?
|
|
||||||
"""
|
|
||||||
query_params.append(is_intermediate)
|
query_params.append(is_intermediate)
|
||||||
|
|
||||||
if show_in_gallery is not None:
|
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||||
query_conditions += """--sql
|
|
||||||
AND show_in_gallery = ?
|
|
||||||
"""
|
|
||||||
query_params.append(show_in_gallery)
|
|
||||||
|
|
||||||
query_pagination = """--sql
|
|
||||||
ORDER BY created_at DESC LIMIT ? OFFSET ?
|
|
||||||
"""
|
|
||||||
|
|
||||||
count_query += query_conditions + ";"
|
|
||||||
count_params = query_params.copy()
|
|
||||||
|
|
||||||
|
# Final images query with pagination
|
||||||
images_query += query_conditions + query_pagination + ";"
|
images_query += query_conditions + query_pagination + ";"
|
||||||
|
# Add all the parameters
|
||||||
images_params = query_params.copy()
|
images_params = query_params.copy()
|
||||||
images_params.append(per_page)
|
images_params.append(per_page)
|
||||||
images_params.append(page * per_page)
|
images_params.append(page * per_page)
|
||||||
|
# Build the list of images, deserializing each row
|
||||||
self._cursor.execute(images_query, images_params)
|
self._cursor.execute(images_query, images_params)
|
||||||
|
|
||||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
|
||||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||||
|
|
||||||
|
# Set up and execute the count query, without pagination
|
||||||
|
count_query += query_conditions + ";"
|
||||||
|
count_params = query_params.copy()
|
||||||
self._cursor.execute(count_query, count_params)
|
self._cursor.execute(count_query, count_params)
|
||||||
|
|
||||||
count = self._cursor.fetchone()[0]
|
count = self._cursor.fetchone()[0]
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
session_id: Optional[str],
|
session_id: Optional[str],
|
||||||
width: int,
|
width: int,
|
||||||
@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
node_id: Optional[str],
|
node_id: Optional[str],
|
||||||
metadata: Optional[ImageMetadata],
|
metadata: Optional[ImageMetadata],
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
show_in_gallery: bool = True,
|
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = (
|
metadata_json = (
|
||||||
@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
"""--sql
|
"""--sql
|
||||||
INSERT OR IGNORE INTO images (
|
INSERT OR IGNORE INTO images (
|
||||||
image_name,
|
image_name,
|
||||||
image_type,
|
image_origin,
|
||||||
image_category,
|
image_category,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
node_id,
|
node_id,
|
||||||
session_id,
|
session_id,
|
||||||
metadata,
|
metadata,
|
||||||
is_intermediate,
|
is_intermediate
|
||||||
show_in_gallery
|
|
||||||
)
|
)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
image_name,
|
image_name,
|
||||||
image_type.value,
|
image_origin.value,
|
||||||
image_category.value,
|
image_category.value,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
session_id,
|
session_id,
|
||||||
metadata_json,
|
metadata_json,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
show_in_gallery,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
|
@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType
|
|||||||
|
|
||||||
from invokeai.app.models.image import (
|
from invokeai.app.models.image import (
|
||||||
ImageCategory,
|
ImageCategory,
|
||||||
ImageType,
|
ResourceOrigin,
|
||||||
InvalidImageCategoryException,
|
InvalidImageCategoryException,
|
||||||
InvalidImageTypeException,
|
InvalidOriginException,
|
||||||
)
|
)
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
@ -44,12 +44,11 @@ class ImageServiceABC(ABC):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
intermediate: bool = False,
|
intermediate: bool = False,
|
||||||
show_in_gallery: bool = True,
|
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
@ -57,7 +56,7 @@ class ImageServiceABC(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
"""Gets an image as a PIL image."""
|
"""Gets an image as a PIL image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
"""Gets an image record."""
|
"""Gets an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||||
"""Gets an image DTO."""
|
"""Gets an image DTO."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||||
"""Gets an image's path."""
|
"""Gets an image's path."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -91,7 +90,7 @@ class ImageServiceABC(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_url(
|
def get_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets an image's or thumbnail's URL."""
|
"""Gets an image's or thumbnail's URL."""
|
||||||
pass
|
pass
|
||||||
@ -101,16 +100,16 @@ class ImageServiceABC(ABC):
|
|||||||
self,
|
self,
|
||||||
page: int = 0,
|
page: int = 0,
|
||||||
per_page: int = 10,
|
per_page: int = 10,
|
||||||
image_type: Optional[ImageType] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
image_category: Optional[ImageCategory] = None,
|
include_categories: Optional[list[ImageCategory]] = None,
|
||||||
|
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
show_in_gallery: Optional[bool] = None,
|
|
||||||
) -> PaginatedResults[ImageDTO]:
|
) -> PaginatedResults[ImageDTO]:
|
||||||
"""Gets a paginated list of image DTOs."""
|
"""Gets a paginated list of image DTOs."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||||
"""Deletes an image."""
|
"""Deletes an image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -171,15 +170,14 @@ class ImageService(ImageServiceABC):
|
|||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
show_in_gallery: bool = True,
|
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
if image_type not in ImageType:
|
if image_origin not in ResourceOrigin:
|
||||||
raise InvalidImageTypeException
|
raise InvalidOriginException
|
||||||
|
|
||||||
if image_category not in ImageCategory:
|
if image_category not in ImageCategory:
|
||||||
raise InvalidImageCategoryException
|
raise InvalidImageCategoryException
|
||||||
@ -195,13 +193,12 @@ class ImageService(ImageServiceABC):
|
|||||||
created_at = self._services.records.save(
|
created_at = self._services.records.save(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
# Meta fields
|
# Meta fields
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
show_in_gallery=show_in_gallery,
|
|
||||||
# Nullable fields
|
# Nullable fields
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@ -209,21 +206,21 @@ class ImageService(ImageServiceABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._services.files.save(
|
self._services.files.save(
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image=image,
|
image=image,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||||
thumbnail_url = self._services.urls.get_image_url(
|
thumbnail_url = self._services.urls.get_image_url(
|
||||||
image_type, image_name, True
|
image_origin, image_name, True
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@ -236,7 +233,6 @@ class ImageService(ImageServiceABC):
|
|||||||
updated_at=created_at, # this is always the same as the created_at at this time
|
updated_at=created_at, # this is always the same as the created_at at this time
|
||||||
deleted_at=None,
|
deleted_at=None,
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
show_in_gallery=show_in_gallery,
|
|
||||||
# Extra non-nullable fields for DTO
|
# Extra non-nullable fields for DTO
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
@ -253,13 +249,13 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def update(
|
def update(
|
||||||
self,
|
self,
|
||||||
image_type: ImageType,
|
image_origin: ResourceOrigin,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
self._services.records.update(image_name, image_type, changes)
|
self._services.records.update(image_name, image_origin, changes)
|
||||||
return self.get_dto(image_type, image_name)
|
return self.get_dto(image_origin, image_name)
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to update image record")
|
self._services.logger.error("Failed to update image record")
|
||||||
raise
|
raise
|
||||||
@ -267,9 +263,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem updating image record")
|
self._services.logger.error("Problem updating image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_type, image_name)
|
return self._services.files.get(image_origin, image_name)
|
||||||
except ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
@ -277,9 +273,9 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image file")
|
self._services.logger.error("Problem getting image file")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_type, image_name)
|
return self._services.records.get(image_origin, image_name)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
@ -287,14 +283,14 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting image record")
|
self._services.logger.error("Problem getting image record")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.records.get(image_type, image_name)
|
image_record = self._services.records.get(image_origin, image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record,
|
||||||
self._services.urls.get_image_url(image_type, image_name),
|
self._services.urls.get_image_url(image_origin, image_name),
|
||||||
self._services.urls.get_image_url(image_type, image_name, True),
|
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@ -306,10 +302,10 @@ class ImageService(ImageServiceABC):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_path(
|
def get_path(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
@ -322,10 +318,10 @@ class ImageService(ImageServiceABC):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_url(
|
def get_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
@ -334,28 +330,28 @@ class ImageService(ImageServiceABC):
|
|||||||
self,
|
self,
|
||||||
page: int = 0,
|
page: int = 0,
|
||||||
per_page: int = 10,
|
per_page: int = 10,
|
||||||
image_type: Optional[ImageType] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
image_category: Optional[ImageCategory] = None,
|
include_categories: Optional[list[ImageCategory]] = None,
|
||||||
|
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
show_in_gallery: Optional[bool] = None,
|
|
||||||
) -> PaginatedResults[ImageDTO]:
|
) -> PaginatedResults[ImageDTO]:
|
||||||
try:
|
try:
|
||||||
results = self._services.records.get_many(
|
results = self._services.records.get_many(
|
||||||
page,
|
page,
|
||||||
per_page,
|
per_page,
|
||||||
image_type,
|
image_origin,
|
||||||
image_category,
|
include_categories,
|
||||||
|
exclude_categories,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
show_in_gallery,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dtos = list(
|
image_dtos = list(
|
||||||
map(
|
map(
|
||||||
lambda r: image_record_to_dto(
|
lambda r: image_record_to_dto(
|
||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(r.image_type, r.image_name),
|
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||||
self._services.urls.get_image_url(
|
self._services.urls.get_image_url(
|
||||||
r.image_type, r.image_name, True
|
r.image_origin, r.image_name, True
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
@ -373,10 +369,10 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.error("Problem getting paginated image DTOs")
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str):
|
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||||
try:
|
try:
|
||||||
self._services.files.delete(image_type, image_name)
|
self._services.files.delete(image_origin, image_name)
|
||||||
self._services.records.delete(image_type, image_name)
|
self._services.records.delete(image_origin, image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from pydantic import BaseModel, Extra, Field, StrictStr
|
from pydantic import BaseModel, Extra, Field, StrictStr
|
||||||
from invokeai.app.models.image import ImageCategory, ImageType
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
|
|||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||||
"""The type of the image."""
|
"""The origin of the image."""
|
||||||
image_category: ImageCategory = Field(description="The category of the image.")
|
image_category: ImageCategory = Field(description="The category of the image.")
|
||||||
"""The category of the image."""
|
"""The category of the image."""
|
||||||
width: int = Field(description="The width of the image in px.")
|
width: int = Field(description="The width of the image in px.")
|
||||||
@ -33,8 +33,6 @@ class ImageRecord(BaseModel):
|
|||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
"""Whether this is an intermediate image."""
|
"""Whether this is an intermediate image."""
|
||||||
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
|
|
||||||
"""Whether this image should be shown in the gallery."""
|
|
||||||
session_id: Optional[str] = Field(
|
session_id: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The session ID that generated this image, if it is a generated image.",
|
description="The session ID that generated this image, if it is a generated image.",
|
||||||
@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel):
|
|||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
"""The unique name of the image."""
|
"""The unique name of the image."""
|
||||||
image_type: ImageType = Field(description="The type of the image.")
|
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||||
"""The type of the image."""
|
"""The origin of the image."""
|
||||||
image_url: str = Field(description="The URL of the image.")
|
image_url: str = Field(description="The URL of the image.")
|
||||||
"""The URL of the image."""
|
"""The URL of the image."""
|
||||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
image_name = image_dict.get("image_name", "unknown")
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||||
image_category = ImageCategory(
|
image_category = ImageCategory(
|
||||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||||
)
|
)
|
||||||
@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
is_intermediate = image_dict.get("is_intermediate", False)
|
is_intermediate = image_dict.get("is_intermediate", False)
|
||||||
show_in_gallery = image_dict.get("show_in_gallery", True)
|
|
||||||
|
|
||||||
raw_metadata = image_dict.get("metadata")
|
raw_metadata = image_dict.get("metadata")
|
||||||
|
|
||||||
@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
|
|
||||||
return ImageRecord(
|
return ImageRecord(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_type=image_type,
|
image_origin=image_origin,
|
||||||
image_category=image_category,
|
image_category=image_category,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
is_intermediate=is_intermediate,
|
is_intermediate=is_intermediate,
|
||||||
show_in_gallery=show_in_gallery,
|
|
||||||
)
|
)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ResourceOrigin
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||||
|
|
||||||
|
|
||||||
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_image_url(
|
def get_image_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Gets the URL for an image or thumbnail."""
|
"""Gets the URL for an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
|
|
||||||
def get_image_url(
|
def get_image_url(
|
||||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
|
|
||||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
return (
|
return (
|
||||||
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
|
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from invokeai.app.api.models.images import ProgressImage
|
|
||||||
from invokeai.app.models.exceptions import CanceledException
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
from invokeai.app.models.image import ProgressImage
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from ...backend.util.util import image_to_dataURL
|
from ...backend.util.util import image_to_dataURL
|
||||||
from ...backend.generator.base import Generator
|
from ...backend.generator.base import Generator
|
||||||
|
Loading…
Reference in New Issue
Block a user