feat(nodes): refactor image types

- Remove `ImageType` entirely, it is confusing
- Create `ResourceOrigin`, may be `internal` or `external`
- Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on
- Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change.

The new setup should account for our types of images, including the combinations we couldn't really handle until now:
- Canvas init and masks
- Canvas when saved-to-gallery or merged
This commit is contained in:
psychedelicious
2023-05-27 21:39:20 +10:00
committed by Kent Keirsey
parent fd47e70c92
commit 160267c71a
17 changed files with 291 additions and 311 deletions

View File

@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional
from invokeai.app.api.models.images import ProgressImage
from typing import Any
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp

View File

@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
from send2trash import send2trash
from invokeai.app.models.image import ImageType
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
def save(
self,
image: PILImageType,
image_type: ImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_type, image_name)
image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save(
self,
image: PILImageType,
image_type: ImageType,
image_origin: ResourceOrigin,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_type, image_name)
image_path = self.get_path(image_origin, image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path):
send2trash(image_path)
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name
self.__output_folder, image_origin, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path)

View File

@ -8,7 +8,7 @@ from typing import Optional, Union
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ImageType,
ResourceOrigin,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC):
def update(
self,
image_name: str,
image_type: ImageType,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC):
self,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record."""
pass
@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC):
def save(
self,
image_name: str,
image_type: ImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
width: int,
height: int,
@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC):
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime:
"""Saves an image record."""
pass
@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL,
image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT,
node_id TEXT,
metadata TEXT,
show_in_gallery BOOLEAN DEFAULT TRUE,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
"""
)
self._cursor.execute(
@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update(
self,
image_name: str,
image_type: ImageType,
image_origin: ResourceOrigin,
changes: ImageRecordChanges,
) -> None:
try:
@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*) FROM images WHERE 1=1
"""
images_query = """--sql
SELECT * FROM images WHERE 1=1
"""
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
query_conditions = ""
query_params = []
if image_type is not None:
query_conditions += """--sql
AND image_type = ?
"""
query_params.append(image_type.value)
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_params.append(image_origin.value)
if image_category is not None:
query_conditions += """--sql
AND image_category = ?
"""
query_params.append(image_category.value)
if include_categories is not None:
## Convert the enum values to unique list of strings
include_category_strings = list(
map(lambda c: c.value, set(include_categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(include_category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*include_category_strings)
if exclude_categories is not None:
## Convert the enum values to unique list of strings
exclude_category_strings = list(
map(lambda c: c.value, set(exclude_categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(exclude_category_strings))
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*exclude_category_strings)
if is_intermediate is not None:
query_conditions += """--sql
AND is_intermediate = ?
"""
query_conditions += f"""AND is_intermediate = ?\n"""
query_params.append(is_intermediate)
if show_in_gallery is not None:
query_conditions += """--sql
AND show_in_gallery = ?
"""
query_params.append(show_in_gallery)
query_pagination = """--sql
ORDER BY created_at DESC LIMIT ? OFFSET ?
"""
count_query += query_conditions + ";"
count_params = query_params.copy()
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(per_page)
images_params.append(page * per_page)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, page=page, pages=pageCount, per_page=per_page, total=count
)
def delete(self, image_type: ImageType, image_name: str) -> None:
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def save(
self,
image_name: str,
image_type: ImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime:
try:
metadata_json = (
@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_type,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
show_in_gallery
is_intermediate
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_type.value,
image_origin.value,
image_category.value,
width,
height,
@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata_json,
is_intermediate,
show_in_gallery,
),
)
self._conn.commit()

View File

@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
ImageType,
ResourceOrigin,
InvalidImageCategoryException,
InvalidImageTypeException,
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
@ -44,12 +44,11 @@ class ImageServiceABC(ABC):
def create(
self,
image: PILImageType,
image_type: ImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -57,7 +56,7 @@ class ImageServiceABC(ABC):
@abstractmethod
def update(
self,
image_type: ImageType,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path."""
pass
@ -91,7 +90,7 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@ -101,16 +100,16 @@ class ImageServiceABC(ABC):
self,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str):
def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image."""
pass
@ -171,15 +170,14 @@ class ImageService(ImageServiceABC):
def create(
self,
image: PILImageType,
image_type: ImageType,
image_origin: ResourceOrigin,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
if image_origin not in ResourceOrigin:
raise InvalidOriginException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
@ -195,13 +193,12 @@ class ImageService(ImageServiceABC):
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Meta fields
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Nullable fields
node_id=node_id,
session_id=session_id,
@ -209,21 +206,21 @@ class ImageService(ImageServiceABC):
)
self._services.files.save(
image_type=image_type,
image_origin=image_origin,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_type, image_name)
image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True
image_origin, image_name, True
)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
@ -236,7 +233,6 @@ class ImageService(ImageServiceABC):
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
@ -253,13 +249,13 @@ class ImageService(ImageServiceABC):
def update(
self,
image_type: ImageType,
image_origin: ResourceOrigin,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_type, changes)
return self.get_dto(image_type, image_name)
self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_origin, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
@ -267,9 +263,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_type, image_name)
return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@ -277,9 +273,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_type, image_name)
return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@ -287,14 +283,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_type, image_name)
image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_type, image_name),
self._services.urls.get_image_url(image_type, image_name, True),
self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_origin, image_name, True),
)
return image_dto
@ -306,10 +302,10 @@ class ImageService(ImageServiceABC):
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -322,10 +318,10 @@ class ImageService(ImageServiceABC):
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
@ -334,28 +330,28 @@ class ImageService(ImageServiceABC):
self,
page: int = 0,
per_page: int = 10,
image_type: Optional[ImageType] = None,
image_category: Optional[ImageCategory] = None,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
page,
per_page,
image_type,
image_category,
image_origin,
include_categories,
exclude_categories,
is_intermediate,
show_in_gallery,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_type, r.image_name),
self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url(
r.image_type, r.image_name, True
r.image_origin, r.image_name, True
),
),
results.items,
@ -373,10 +369,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_type: ImageType, image_name: str):
def delete(self, image_origin: ResourceOrigin, image_name: str):
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@ -1,7 +1,7 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
@ -33,8 +33,6 @@ class ImageRecord(BaseModel):
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
"""Whether this image should be shown in the gallery."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The origin of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
show_in_gallery = image_dict.get("show_in_gallery", True)
raw_metadata = image_dict.get("metadata")
@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
return ImageRecord(
image_name=image_name,
image_type=image_type,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
)

View File

@ -1,7 +1,7 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType
from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
@abstractmethod
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
self._base_url = base_url
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"