# InvokeAI/invokeai/app/services/images.py
import json
2023-05-21 12:15:44 +00:00
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image as PILImageType
2023-05-21 12:15:44 +00:00
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import (ImageCategory,
InvalidImageCategoryException,
InvalidOriginException, ResourceOrigin)
from invokeai.app.services.board_image_record_storage import \
BoardImageRecordStorageBase
from invokeai.app.services.graph import Graph
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException, ImageFileNotFoundException,
ImageFileSaveException, ImageFileStorageBase)
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException, ImageRecordNotFoundException,
ImageRecordSaveException, ImageRecordStorageBase, OffsetPaginatedResults)
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.models.image_record import (ImageDTO, ImageRecord,
ImageRecordChanges,
image_record_to_dto)
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
2023-05-22 05:48:12 +00:00
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
2023-05-21 12:15:44 +00:00
class ImageServiceABC(ABC):
    """Abstract interface of the high-level image management service.

    Concrete services must implement every method below; none of them has a
    default behavior.
    """

    @abstractmethod
    def create(
        self,
        image: PILImageType,
        image_origin: ResourceOrigin,
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        is_intermediate: bool = False,
        metadata: Optional[dict] = None,
    ) -> ImageDTO:
        """Create an image, persisting both the file and its metadata."""

    @abstractmethod
    def update(
        self,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        """Apply `changes` to the image called `image_name`."""

    @abstractmethod
    def get_pil_image(self, image_name: str) -> PILImageType:
        """Return the image as a PIL image."""

    @abstractmethod
    def get_record(self, image_name: str) -> ImageRecord:
        """Return the stored record of an image."""

    @abstractmethod
    def get_dto(self, image_name: str) -> ImageDTO:
        """Return the DTO of an image."""

    @abstractmethod
    def get_metadata(self, image_name: str) -> ImageMetadata:
        """Return the metadata of an image."""

    @abstractmethod
    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
        """Return the path of an image (or of its thumbnail)."""

    @abstractmethod
    def validate_path(self, path: str) -> bool:
        """Validate the path of an image."""

    @abstractmethod
    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
        """Return the URL of an image or of its thumbnail."""

    @abstractmethod
    def get_many(
        self,
        offset: int = 0,
        limit: int = 10,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        """Return one page of image DTOs."""

    @abstractmethod
    def delete(self, image_name: str):
        """Delete a single image."""

    @abstractmethod
    def delete_images_on_board(self, board_id: str):
        """Delete every image that belongs to a board."""
2023-05-21 12:15:44 +00:00
class ImageServiceDependencies:
    """Bundle of the collaborating services that `ImageService` relies on.

    The constructor only stores the objects it is given; note that two
    parameters are named differently from the attribute they populate
    (`image_record_storage` -> `image_records`, `url` -> `urls`, ...).
    """

    # Storage backends
    image_records: ImageRecordStorageBase
    image_files: ImageFileStorageBase
    board_image_records: BoardImageRecordStorageBase
    # Supporting services
    urls: UrlServiceBase
    logger: Logger
    names: NameServiceBase
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]

    def __init__(
        self,
        image_record_storage: ImageRecordStorageBase,
        image_file_storage: ImageFileStorageBase,
        board_image_record_storage: BoardImageRecordStorageBase,
        url: UrlServiceBase,
        logger: Logger,
        names: NameServiceBase,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
    ):
        # Supporting services
        self.logger = logger
        self.names = names
        self.urls = url
        self.graph_execution_manager = graph_execution_manager
        # Storage backends
        self.image_records = image_record_storage
        self.image_files = image_file_storage
        self.board_image_records = board_image_record_storage
2023-05-21 12:15:44 +00:00
class ImageService(ImageServiceABC):
    """Image service backed by the collaborators in `ImageServiceDependencies`.

    Each public method delegates to the record storage, file storage, URL
    service and/or board-image storage. When a collaborator raises, a short
    message is logged and the exception is re-raised unchanged. The only
    failure that is swallowed is a session-graph parse error, which is logged
    as a warning and treated as "no graph".
    """

    _services: ImageServiceDependencies

    def __init__(self, services: ImageServiceDependencies):
        self._services = services

    def _parse_session_graph(self, session_raw):
        """Extract the metadata graph from a raw session, or None on failure.

        Parsing is best-effort: any exception is logged as a warning so that
        a malformed session never prevents the caller from proceeding.
        """
        try:
            return get_metadata_graph_from_raw_session(session_raw)
        except Exception as e:
            self._services.logger.warning(f"Failed to parse session graph: {e}")
            return None

    def _record_to_dto(self, image_name: str, image_record: ImageRecord) -> ImageDTO:
        """Build the DTO for a record: image URL, thumbnail URL and board."""
        return image_record_to_dto(
            image_record,
            self._services.urls.get_image_url(image_name),
            self._services.urls.get_image_url(image_name, True),
            self._services.board_image_records.get_board_for_image(image_name),
        )

    def create(
        self,
        image: PILImageType,
        image_origin: ResourceOrigin,
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        is_intermediate: bool = False,
        metadata: Optional[dict] = None,
    ) -> ImageDTO:
        """Save a new image record and image file, then return the image's DTO.

        Raises:
            InvalidOriginException: `image_origin` is not a `ResourceOrigin`.
            InvalidImageCategoryException: `image_category` is not an
                `ImageCategory`.
            ImageRecordSaveException, ImageFileSaveException, or any other
                exception raised by a collaborator (logged, then re-raised).
        """
        if image_origin not in ResourceOrigin:
            raise InvalidOriginException

        if image_category not in ImageCategory:
            raise InvalidImageCategoryException

        image_name = self._services.names.create_image_name()

        # The session graph is only passed to the file storage; it is looked
        # up here so a missing/unparseable session degrades to graph=None.
        graph = None
        if session_id is not None:
            session_raw = self._services.graph_execution_manager.get_raw(session_id)
            if session_raw is not None:
                graph = self._parse_session_graph(session_raw)

        (width, height) = image.size

        try:
            # TODO: Consider using a transaction here to ensure consistency between storage and database
            # (the record is saved first, so a failing file save leaves the record behind).
            self._services.image_records.save(
                # Non-nullable fields
                image_name=image_name,
                image_origin=image_origin,
                image_category=image_category,
                width=width,
                height=height,
                # Meta fields
                is_intermediate=is_intermediate,
                # Nullable fields
                node_id=node_id,
                metadata=metadata,
                session_id=session_id,
            )

            self._services.image_files.save(
                image_name=image_name, image=image, metadata=metadata, graph=graph
            )

            return self.get_dto(image_name)
        except ImageRecordSaveException:
            self._services.logger.error("Failed to save image record")
            raise
        except ImageFileSaveException:
            self._services.logger.error("Failed to save image file")
            raise
        except Exception:
            self._services.logger.error("Problem saving image record and file")
            raise

    def update(
        self,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        """Apply `changes` to the image record and return the refreshed DTO.

        Any exception from the record storage or from `get_dto` is logged and
        re-raised.
        """
        try:
            self._services.image_records.update(image_name, changes)
            return self.get_dto(image_name)
        except ImageRecordSaveException:
            self._services.logger.error("Failed to update image record")
            raise
        except Exception:
            self._services.logger.error("Problem updating image record")
            raise

    def get_pil_image(self, image_name: str) -> PILImageType:
        """Return the image file as a PIL image; errors are logged and re-raised."""
        try:
            return self._services.image_files.get(image_name)
        except ImageFileNotFoundException:
            self._services.logger.error("Failed to get image file")
            raise
        except Exception:
            self._services.logger.error("Problem getting image file")
            raise

    def get_record(self, image_name: str) -> ImageRecord:
        """Return the stored image record; errors are logged and re-raised."""
        try:
            return self._services.image_records.get(image_name)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception:
            self._services.logger.error("Problem getting image record")
            raise

    def get_dto(self, image_name: str) -> ImageDTO:
        """Return the image's DTO (record plus image URL, thumbnail URL and board).

        Errors are logged and re-raised.
        """
        try:
            image_record = self._services.image_records.get(image_name)
            return self._record_to_dto(image_name, image_record)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception:
            self._services.logger.error("Problem getting image DTO")
            raise

    def get_metadata(self, image_name: str) -> ImageMetadata:
        """Return the image's metadata together with its session graph.

        An empty `ImageMetadata` is returned when the record has no session.
        If the session cannot be found or its graph cannot be parsed, the
        graph is None. Errors from the storages are logged and re-raised.
        """
        try:
            image_record = self._services.image_records.get(image_name)

            # NOTE(review): session-less records get an empty ImageMetadata
            # even if metadata was stored with the record - confirm intended.
            if not image_record.session_id:
                return ImageMetadata()

            session_raw = self._services.graph_execution_manager.get_raw(
                image_record.session_id
            )

            graph = None
            if session_raw:
                graph = self._parse_session_graph(session_raw)

            metadata = self._services.image_records.get_metadata(image_name)
            return ImageMetadata(graph=graph, metadata=metadata)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception:
            self._services.logger.error("Problem getting image metadata")
            raise

    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
        """Return the path of the image, or of its thumbnail if `thumbnail`."""
        try:
            return self._services.image_files.get_path(image_name, thumbnail)
        except Exception:
            self._services.logger.error("Problem getting image path")
            raise

    def validate_path(self, path: str) -> bool:
        """Return the file storage's verdict on `path`."""
        try:
            return self._services.image_files.validate_path(path)
        except Exception:
            self._services.logger.error("Problem validating image path")
            raise

    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
        """Return the URL of the image, or of its thumbnail if `thumbnail`."""
        try:
            return self._services.urls.get_image_url(image_name, thumbnail)
        except Exception:
            self._services.logger.error("Problem getting image URL")
            raise

    def get_many(
        self,
        offset: int = 0,
        limit: int = 10,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
        board_id: Optional[str] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        """Return one page of image DTOs matching the given filters.

        The page's offset, limit and total are copied from the record storage
        result. Errors are logged and re-raised.
        """
        try:
            results = self._services.image_records.get_many(
                offset,
                limit,
                image_origin,
                categories,
                is_intermediate,
                board_id,
            )

            # One board lookup is issued per record on the page.
            image_dtos = [
                self._record_to_dto(record.image_name, record)
                for record in results.items
            ]

            return OffsetPaginatedResults[ImageDTO](
                items=image_dtos,
                offset=results.offset,
                limit=results.limit,
                total=results.total,
            )
        except Exception:
            self._services.logger.error("Problem getting paginated image DTOs")
            raise

    def delete(self, image_name: str) -> None:
        """Delete the image file, then the image record.

        Errors are logged and re-raised; if the file deletion fails the
        record is left in place.
        """
        try:
            self._services.image_files.delete(image_name)
            self._services.image_records.delete(image_name)
        except ImageRecordDeleteException:
            self._services.logger.error("Failed to delete image record")
            raise
        except ImageFileDeleteException:
            self._services.logger.error("Failed to delete image file")
            raise
        except Exception:
            self._services.logger.error("Problem deleting image record and file")
            raise

    def delete_images_on_board(self, board_id: str) -> None:
        """Delete the files and records of the images listed for a board.

        Files are deleted one by one first, then all records in a single
        `delete_many` call. Errors are logged and re-raised; a failure while
        deleting files leaves the records (and any remaining files) in place.
        """
        try:
            # NOTE(review): the result exposes `.items`, so it may be a single
            # page - confirm get_images_for_board returns every image.
            images = self._services.board_image_records.get_images_for_board(board_id)
            image_name_list = [record.image_name for record in images.items]
            for image_name in image_name_list:
                self._services.image_files.delete(image_name)
            self._services.image_records.delete_many(image_name_list)
        except ImageRecordDeleteException:
            self._services.logger.error("Failed to delete image records")
            raise
        except ImageFileDeleteException:
            self._services.logger.error("Failed to delete image files")
            raise
        except Exception:
            self._services.logger.error("Problem deleting image records and files")
            raise