mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): refactor image types
- Remove `ImageType` entirely, it is confusing - Create `ResourceOrigin`, may be `internal` or `external` - Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on - Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change. The new setup should account for our types of images, including the combinations we couldn't really handle until now: - Canvas init and masks - Canvas when saved-to-gallery or merged
This commit is contained in:
committed by
Kent Keirsey
parent
fd47e70c92
commit
160267c71a
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Optional
|
||||
from invokeai.app.api.models.images import ProgressImage
|
||||
from typing import Any
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
|
||||
from PIL import Image, PngImagePlugin
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the internal path to an image or thumbnail."""
|
||||
pass
|
||||
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image and its thumbnail (if one exists)."""
|
||||
pass
|
||||
|
||||
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||
for image_type in ImageType:
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
for image_origin in ResourceOrigin:
|
||||
Path(os.path.join(output_folder, image_origin)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
cache_item = self.__get_cache(image_path)
|
||||
if cache_item:
|
||||
return cache_item
|
||||
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
metadata: Optional[ImageMetadata] = None,
|
||||
thumbnail_size: int = 256,
|
||||
) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
image_path = self.get_path(image_origin, image_name)
|
||||
|
||||
if metadata is not None:
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
image.save(image_path, "PNG")
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
|
||||
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||
thumbnail_image.save(thumbnail_path)
|
||||
|
||||
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
except Exception as e:
|
||||
raise ImageFileSaveException from e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
basename = os.path.basename(image_name)
|
||||
image_path = self.get_path(image_type, basename)
|
||||
image_path = self.get_path(image_origin, basename)
|
||||
|
||||
if os.path.exists(image_path):
|
||||
send2trash(image_path)
|
||||
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
del self.__cache[image_path]
|
||||
|
||||
thumbnail_name = get_thumbnail_name(image_name)
|
||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||
thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
send2trash(thumbnail_path)
|
||||
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
# strip out any relative path shenanigans
|
||||
basename = os.path.basename(image_name)
|
||||
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
||||
if thumbnail:
|
||||
thumbnail_name = get_thumbnail_name(basename)
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
||||
self.__output_folder, image_origin, "thumbnails", thumbnail_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, basename)
|
||||
path = os.path.join(self.__output_folder, image_origin, basename)
|
||||
|
||||
abspath = os.path.abspath(path)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from typing import Optional, Union
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: Implement an `update()` method
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
"""Updates an image record."""
|
||||
@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC):
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC):
|
||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
"""Deletes an image record."""
|
||||
pass
|
||||
|
||||
@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC):
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
width: int,
|
||||
height: int,
|
||||
@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> datetime:
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
CREATE TABLE IF NOT EXISTS images (
|
||||
image_name TEXT NOT NULL PRIMARY KEY,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_type TEXT NOT NULL,
|
||||
image_origin TEXT NOT NULL,
|
||||
-- This is an enum in python, unrestricted string here for flexibility
|
||||
image_category TEXT NOT NULL,
|
||||
width INTEGER NOT NULL,
|
||||
@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id TEXT,
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
show_in_gallery BOOLEAN DEFAULT TRUE,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
||||
CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
||||
def get(
|
||||
self, image_origin: ResourceOrigin, image_name: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def update(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = """--sql
|
||||
SELECT * FROM images WHERE 1=1
|
||||
"""
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_type is not None:
|
||||
query_conditions += """--sql
|
||||
AND image_type = ?
|
||||
"""
|
||||
query_params.append(image_type.value)
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if image_category is not None:
|
||||
query_conditions += """--sql
|
||||
AND image_category = ?
|
||||
"""
|
||||
query_params.append(image_category.value)
|
||||
if include_categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
include_category_strings = list(
|
||||
map(lambda c: c.value, set(include_categories))
|
||||
)
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(include_category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
query_params.append(*include_category_strings)
|
||||
|
||||
if exclude_categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
exclude_category_strings = list(
|
||||
map(lambda c: c.value, set(exclude_categories))
|
||||
)
|
||||
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(exclude_category_strings))
|
||||
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
query_params.append(*exclude_category_strings)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND is_intermediate = ?
|
||||
"""
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if show_in_gallery is not None:
|
||||
query_conditions += """--sql
|
||||
AND show_in_gallery = ?
|
||||
"""
|
||||
query_params.append(show_in_gallery)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(per_page)
|
||||
images_params.append(page * per_page)
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
|
||||
count = self._cursor.fetchone()[0]
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
def save(
|
||||
self,
|
||||
image_name: str,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
session_id: Optional[str],
|
||||
width: int,
|
||||
@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id: Optional[str],
|
||||
metadata: Optional[ImageMetadata],
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> datetime:
|
||||
try:
|
||||
metadata_json = (
|
||||
@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
image_type,
|
||||
image_origin,
|
||||
image_category,
|
||||
width,
|
||||
height,
|
||||
node_id,
|
||||
session_id,
|
||||
metadata,
|
||||
is_intermediate,
|
||||
show_in_gallery
|
||||
is_intermediate
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
|
||||
""",
|
||||
(
|
||||
image_name,
|
||||
image_type.value,
|
||||
image_origin.value,
|
||||
image_category.value,
|
||||
width,
|
||||
height,
|
||||
@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
session_id,
|
||||
metadata_json,
|
||||
is_intermediate,
|
||||
show_in_gallery,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ImageType,
|
||||
ResourceOrigin,
|
||||
InvalidImageCategoryException,
|
||||
InvalidImageTypeException,
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
@ -44,12 +44,11 @@ class ImageServiceABC(ABC):
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -57,7 +56,7 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
"""Gets an image as a PIL image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
"""Gets an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
"""Gets an image DTO."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -91,7 +90,7 @@ class ImageServiceABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets an image's or thumbnail's URL."""
|
||||
pass
|
||||
@ -101,16 +100,16 @@ class ImageServiceABC(ABC):
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
"""Deletes an image."""
|
||||
pass
|
||||
|
||||
@ -171,15 +170,14 @@ class ImageService(ImageServiceABC):
|
||||
def create(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
is_intermediate: bool = False,
|
||||
show_in_gallery: bool = True,
|
||||
) -> ImageDTO:
|
||||
if image_type not in ImageType:
|
||||
raise InvalidImageTypeException
|
||||
if image_origin not in ResourceOrigin:
|
||||
raise InvalidOriginException
|
||||
|
||||
if image_category not in ImageCategory:
|
||||
raise InvalidImageCategoryException
|
||||
@ -195,13 +193,12 @@ class ImageService(ImageServiceABC):
|
||||
created_at = self._services.records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Meta fields
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
@ -209,21 +206,21 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||
image_url = self._services.urls.get_image_url(image_origin, image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(
|
||||
image_type, image_name, True
|
||||
image_origin, image_name, True
|
||||
)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
@ -236,7 +233,6 @@ class ImageService(ImageServiceABC):
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
@ -253,13 +249,13 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def update(
|
||||
self,
|
||||
image_type: ImageType,
|
||||
image_origin: ResourceOrigin,
|
||||
image_name: str,
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, image_type, changes)
|
||||
return self.get_dto(image_type, image_name)
|
||||
self._services.records.update(image_name, image_origin, changes)
|
||||
return self.get_dto(image_origin, image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
raise
|
||||
@ -267,9 +263,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem updating image record")
|
||||
raise e
|
||||
|
||||
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||
def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_type, image_name)
|
||||
return self._services.files.get(image_origin, image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -277,9 +273,9 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image file")
|
||||
raise e
|
||||
|
||||
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||
def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_type, image_name)
|
||||
return self._services.records.get(image_origin, image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -287,14 +283,14 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting image record")
|
||||
raise e
|
||||
|
||||
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||
def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_type, image_name)
|
||||
image_record = self._services.records.get(image_origin, image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_type, image_name),
|
||||
self._services.urls.get_image_url(image_type, image_name, True),
|
||||
self._services.urls.get_image_url(image_origin, image_name),
|
||||
self._services.urls.get_image_url(image_origin, image_name, True),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -306,10 +302,10 @@ class ImageService(ImageServiceABC):
|
||||
raise e
|
||||
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||
return self._services.files.get_path(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@ -322,10 +318,10 @@ class ImageService(ImageServiceABC):
|
||||
raise e
|
||||
|
||||
def get_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
try:
|
||||
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||
return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
@ -334,28 +330,28 @@ class ImageService(ImageServiceABC):
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
image_type: Optional[ImageType] = None,
|
||||
image_category: Optional[ImageCategory] = None,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
show_in_gallery: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
page,
|
||||
per_page,
|
||||
image_type,
|
||||
image_category,
|
||||
image_origin,
|
||||
include_categories,
|
||||
exclude_categories,
|
||||
is_intermediate,
|
||||
show_in_gallery,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_type, r.image_name),
|
||||
self._services.urls.get_image_url(r.image_origin, r.image_name),
|
||||
self._services.urls.get_image_url(
|
||||
r.image_type, r.image_name, True
|
||||
r.image_origin, r.image_name, True
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
@ -373,10 +369,10 @@ class ImageService(ImageServiceABC):
|
||||
self._services.logger.error("Problem getting paginated image DTOs")
|
||||
raise e
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str):
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_type, image_name)
|
||||
self._services.records.delete(image_type, image_name)
|
||||
self._services.files.delete(image_origin, image_name)
|
||||
self._services.records.delete(image_origin, image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ImageType
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_category: ImageCategory = Field(description="The category of the image.")
|
||||
"""The category of the image."""
|
||||
width: int = Field(description="The width of the image in px.")
|
||||
@ -33,8 +33,6 @@ class ImageRecord(BaseModel):
|
||||
"""The deleted timestamp of the image."""
|
||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||
"""Whether this is an intermediate image."""
|
||||
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
|
||||
"""Whether this image should be shown in the gallery."""
|
||||
session_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The session ID that generated this image, if it is a generated image.",
|
||||
@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
image_name: str = Field(description="The unique name of the image.")
|
||||
"""The unique name of the image."""
|
||||
image_type: ImageType = Field(description="The type of the image.")
|
||||
"""The type of the image."""
|
||||
image_origin: ResourceOrigin = Field(description="The type of the image.")
|
||||
"""The origin of the image."""
|
||||
image_url: str = Field(description="The URL of the image.")
|
||||
"""The URL of the image."""
|
||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||
@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||
is_intermediate = image_dict.get("is_intermediate", False)
|
||||
show_in_gallery = image_dict.get("show_in_gallery", True)
|
||||
|
||||
raw_metadata = image_dict.get("metadata")
|
||||
|
||||
@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
|
||||
return ImageRecord(
|
||||
image_name=image_name,
|
||||
image_type=image_type,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
is_intermediate=is_intermediate,
|
||||
show_in_gallery=show_in_gallery,
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.image import ResourceOrigin
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
|
||||
self._base_url = base_url
|
||||
|
||||
def get_image_url(
|
||||
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||
self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
|
||||
) -> str:
|
||||
image_basename = os.path.basename(image_name)
|
||||
|
||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||
if thumbnail:
|
||||
return (
|
||||
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
|
||||
f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
|
||||
)
|
||||
|
||||
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
||||
return f"{self._base_url}/images/{image_origin.value}/{image_basename}"
|
||||
|
Reference in New Issue
Block a user