InvokeAI/invokeai/app/services/images.py
psychedelicious a1773197e9 feat(nodes): remove image_origin from most places
- remove `image_origin` from most places where we interact with images
- consolidate image file storage into a single `images/` dir

Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads.

It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images.

The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it makes getting image paths on disk painful.

So, I have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
2023-06-14 23:08:27 +10:00

381 lines
12 KiB
Python

from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, TYPE_CHECKING, Union
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
InvalidImageCategoryException,
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
    """High-level service for image management.

    Declares the operations an image service must provide: creating,
    updating, fetching (as PIL image, record, DTO, path or URL), listing and
    deleting images. Images are addressed by `image_name` alone.
    """

    @abstractmethod
    def create(
        self,
        image: PILImageType,
        image_origin: ResourceOrigin,
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        is_intermediate: bool = False,
    ) -> ImageDTO:
        """Creates an image, storing the file and its metadata.

        Args:
            image: The PIL image to store.
            image_origin: The origin of the image.
            image_category: The category of the image.
            node_id: Optional id of the node associated with the image.
            session_id: Optional id of the session associated with the image.
            is_intermediate: Whether the image is flagged as intermediate.

        Returns:
            The DTO describing the newly created image.
        """
        pass

    @abstractmethod
    def update(
        self,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        """Updates an image and returns its DTO."""
        pass

    @abstractmethod
    def get_pil_image(self, image_name: str) -> PILImageType:
        """Gets an image as a PIL image."""
        pass

    @abstractmethod
    def get_record(self, image_name: str) -> ImageRecord:
        """Gets an image record."""
        pass

    @abstractmethod
    def get_dto(self, image_name: str) -> ImageDTO:
        """Gets an image DTO."""
        pass

    @abstractmethod
    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
        """Gets an image's or thumbnail's path."""
        pass

    @abstractmethod
    def validate_path(self, path: str) -> bool:
        """Validates an image's path."""
        pass

    @abstractmethod
    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
        """Gets an image's or thumbnail's URL."""
        pass

    @abstractmethod
    def get_many(
        self,
        offset: int = 0,
        limit: int = 10,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        """Gets a paginated list of image DTOs.

        Args:
            offset: Pagination offset.
            limit: Pagination limit.
            image_origin: Optional origin filter.
            categories: Optional category filter.
            is_intermediate: Optional intermediate-flag filter.
        """
        pass

    @abstractmethod
    def delete(self, image_name: str) -> None:
        """Deletes an image."""
        pass
class ImageServiceDependencies:
    """Service dependencies for the ImageService.

    A plain container: each constructor argument is stored on the instance
    unchanged, under the attribute name listed below.
    """

    records: ImageRecordStorageBase  # set from `image_record_storage`
    files: ImageFileStorageBase  # set from `image_file_storage`
    metadata: MetadataServiceBase
    urls: UrlServiceBase  # set from `url`
    logger: Logger
    names: NameServiceBase
    graph_execution_manager: ItemStorageABC["GraphExecutionState"]

    def __init__(
        self,
        image_record_storage: ImageRecordStorageBase,
        image_file_storage: ImageFileStorageBase,
        metadata: MetadataServiceBase,
        url: UrlServiceBase,
        logger: Logger,
        names: NameServiceBase,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
    ):
        """Stores every supplied service on the instance."""
        # Storage backends (attribute names differ from the parameter names).
        self.files = image_file_storage
        self.records = image_record_storage

        # Supporting services.
        self.urls = url
        self.names = names
        self.metadata = metadata
        self.graph_execution_manager = graph_execution_manager
        self.logger = logger
class ImageService(ImageServiceABC):
    """Image service that delegates to record storage, file storage and helper services.

    Every public method logs an error through the injected logger and then
    re-raises when one of the underlying services raises.
    """

    _services: ImageServiceDependencies

    def __init__(
        self,
        image_record_storage: ImageRecordStorageBase,
        image_file_storage: ImageFileStorageBase,
        metadata: MetadataServiceBase,
        url: UrlServiceBase,
        logger: Logger,
        names: NameServiceBase,
        graph_execution_manager: ItemStorageABC["GraphExecutionState"],
    ):
        """Bundles the injected services into an ImageServiceDependencies instance."""
        self._services = ImageServiceDependencies(
            image_record_storage=image_record_storage,
            image_file_storage=image_file_storage,
            metadata=metadata,
            url=url,
            logger=logger,
            names=names,
            graph_execution_manager=graph_execution_manager,
        )

    def create(
        self,
        image: PILImageType,
        image_origin: ResourceOrigin,
        image_category: ImageCategory,
        node_id: Optional[str] = None,
        session_id: Optional[str] = None,
        is_intermediate: bool = False,
    ) -> ImageDTO:
        """Creates an image: saves its record, then its file, and returns its DTO.

        The image name comes from the name service. Metadata is only built
        when both `session_id` and `node_id` are provided.

        Raises:
            InvalidOriginException: `image_origin` is not in ResourceOrigin.
            InvalidImageCategoryException: `image_category` is not in ImageCategory.
            Exception: anything raised while saving the record or file, or
                while building the DTO, is logged and re-raised unchanged.
        """
        # NOTE(review): these membership tests rely on the enum classes'
        # `__contains__`; confirm they accept arbitrary (non-member) operands.
        if image_origin not in ResourceOrigin:
            raise InvalidOriginException

        if image_category not in ImageCategory:
            raise InvalidImageCategoryException

        image_name = self._services.names.create_image_name()

        # Not wrapped in the try below: a failure here propagates without a log entry.
        metadata = self._get_metadata(session_id, node_id)

        (width, height) = image.size

        try:
            # TODO: Consider using a transaction here to ensure consistency between storage and database
            # The record is written before the file; nothing here removes the
            # record again if the file save subsequently fails.
            created_at = self._services.records.save(
                # Non-nullable fields
                image_name=image_name,
                image_origin=image_origin,
                image_category=image_category,
                width=width,
                height=height,
                # Meta fields
                is_intermediate=is_intermediate,
                # Nullable fields
                node_id=node_id,
                session_id=session_id,
                metadata=metadata,
            )

            self._services.files.save(
                image_name=image_name,
                image=image,
                metadata=metadata,
            )

            image_url = self._services.urls.get_image_url(image_name)
            thumbnail_url = self._services.urls.get_image_url(image_name, True)

            return ImageDTO(
                # Non-nullable fields
                image_name=image_name,
                image_origin=image_origin,
                image_category=image_category,
                width=width,
                height=height,
                # Nullable fields
                node_id=node_id,
                session_id=session_id,
                metadata=metadata,
                # Meta fields
                created_at=created_at,
                updated_at=created_at,  # this is always the same as the created_at at this time
                deleted_at=None,
                is_intermediate=is_intermediate,
                # Extra non-nullable fields for DTO
                image_url=image_url,
                thumbnail_url=thumbnail_url,
            )
        except ImageRecordSaveException:
            self._services.logger.error("Failed to save image record")
            raise
        except ImageFileSaveException:
            self._services.logger.error("Failed to save image file")
            raise
        except Exception:
            self._services.logger.error("Problem saving image record and file")
            raise

    def update(
        self,
        image_name: str,
        changes: ImageRecordChanges,
    ) -> ImageDTO:
        """Applies `changes` to the image's record and returns the refreshed DTO."""
        try:
            self._services.records.update(image_name, changes)
            return self.get_dto(image_name)
        except ImageRecordSaveException:
            self._services.logger.error("Failed to update image record")
            raise
        except Exception:
            self._services.logger.error("Problem updating image record")
            raise

    def get_pil_image(self, image_name: str) -> PILImageType:
        """Gets an image as a PIL image from file storage."""
        try:
            return self._services.files.get(image_name)
        except ImageFileNotFoundException:
            self._services.logger.error("Failed to get image file")
            raise
        except Exception:
            self._services.logger.error("Problem getting image file")
            raise

    def get_record(self, image_name: str) -> ImageRecord:
        """Gets an image record from record storage."""
        try:
            return self._services.records.get(image_name)
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception:
            self._services.logger.error("Problem getting image record")
            raise

    def get_dto(self, image_name: str) -> ImageDTO:
        """Gets an image DTO: the stored record plus its image and thumbnail URLs."""
        try:
            image_record = self._services.records.get(image_name)

            image_dto = image_record_to_dto(
                image_record,
                self._services.urls.get_image_url(image_name),
                self._services.urls.get_image_url(image_name, True),
            )

            return image_dto
        except ImageRecordNotFoundException:
            self._services.logger.error("Image record not found")
            raise
        except Exception:
            self._services.logger.error("Problem getting image DTO")
            raise

    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
        """Gets an image's (or, with `thumbnail=True`, its thumbnail's) path from file storage."""
        try:
            return self._services.files.get_path(image_name, thumbnail)
        except Exception:
            self._services.logger.error("Problem getting image path")
            raise

    def validate_path(self, path: str) -> bool:
        """Validates an image's path via file storage."""
        try:
            return self._services.files.validate_path(path)
        except Exception:
            self._services.logger.error("Problem validating image path")
            raise

    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
        """Gets an image's or thumbnail's URL from the URL service."""
        try:
            return self._services.urls.get_image_url(image_name, thumbnail)
        except Exception:
            # Previously logged "image path" (copy-paste from get_path).
            self._services.logger.error("Problem getting image URL")
            raise

    def get_many(
        self,
        offset: int = 0,
        limit: int = 10,
        image_origin: Optional[ResourceOrigin] = None,
        categories: Optional[list[ImageCategory]] = None,
        is_intermediate: Optional[bool] = None,
    ) -> OffsetPaginatedResults[ImageDTO]:
        """Gets a paginated list of image DTOs.

        The filters are forwarded to record storage as-is; each returned
        record is converted to a DTO with its image and thumbnail URLs. The
        offset, limit and total of the result are those reported by record
        storage.
        """
        try:
            results = self._services.records.get_many(
                offset,
                limit,
                image_origin,
                categories,
                is_intermediate,
            )

            urls = self._services.urls
            image_dtos = [
                image_record_to_dto(
                    record,
                    urls.get_image_url(record.image_name),
                    urls.get_image_url(record.image_name, True),
                )
                for record in results.items
            ]

            return OffsetPaginatedResults[ImageDTO](
                items=image_dtos,
                offset=results.offset,
                limit=results.limit,
                total=results.total,
            )
        except Exception:
            self._services.logger.error("Problem getting paginated image DTOs")
            raise

    def delete(self, image_name: str) -> None:
        """Deletes an image's file, then its record."""
        try:
            # File first: if this raises, the record is left untouched.
            self._services.files.delete(image_name)
            self._services.records.delete(image_name)
        except ImageRecordDeleteException:
            self._services.logger.error("Failed to delete image record")
            raise
        except ImageFileDeleteException:
            self._services.logger.error("Failed to delete image file")
            raise
        except Exception:
            self._services.logger.error("Problem deleting image record and file")
            raise

    def _get_metadata(
        self, session_id: Optional[str] = None, node_id: Optional[str] = None
    ) -> Union[ImageMetadata, None]:
        """Get the metadata for a node.

        Returns None unless both `session_id` and `node_id` are given;
        otherwise looks the session up and asks the metadata service to build
        the image metadata for that node.
        """
        if node_id is None or session_id is None:
            return None

        session = self._services.graph_execution_manager.get(session_id)
        return self._services.metadata.create_image_metadata(session, node_id)