feat(nodes): wip image storage implementation

This commit is contained in:
psychedelicious 2023-05-21 20:05:33 +10:00 committed by Kent Keirsey
parent d4aa79acd7
commit 1b75d899ae
13 changed files with 383 additions and 268 deletions

View File

@ -2,9 +2,8 @@
import os
from types import ModuleType
from invokeai.app.services.database.images.sqlite_images_db_service import (
SqliteImageDb,
)
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger
@ -14,7 +13,7 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
@ -63,7 +62,9 @@ class ApiDependencies:
urls = LocalUrlService()
images = DiskImageStorage(f"{output_folder}/images", metadata_service=metadata)
image_file_storage = DiskImageFileStorage(
f"{output_folder}/images", metadata_service=metadata
)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -72,7 +73,14 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions"
)
images_db = SqliteImageDb(filename=db_location)
image_record_storage = SqliteImageRecordStorage(db_location)
images_new = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
)
# register event handler to update the `results` table when a graph execution state is inserted or updated
# graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
@ -82,8 +90,8 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
images_new=images_new,
metadata=metadata,
images_db=images_db,
urls=urls,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -0,0 +1,165 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@ -4,28 +4,28 @@ from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.image_db import ImageRecordServiceBase
from invokeai.app.services.image_storage import ImageStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
image_records_router = APIRouter(prefix="/v1/images", tags=["images", "records"])
image_records_router = APIRouter(
prefix="/v1/images/records", tags=["images", "records"]
)
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
async def get_image_record(
image_type: ImageType = Path(description="The type of the image record to get"),
image_name: str = Path(description="The id of the image record to get"),
) -> ImageRecord:
) -> ImageDTO:
"""Gets an image record by id"""
try:
return ApiDependencies.invoker.services.images_new.get_record(
return ApiDependencies.invoker.services.images_new.get_dto(
image_type=image_type, image_name=image_name
)
except ImageRecordServiceBase.ImageRecordNotFoundException:
except Exception as e:
raise HTTPException(status_code=404)
@ -42,17 +42,17 @@ async def list_image_records(
per_page: int = Query(
default=10, description="The number of image records per page"
),
) -> PaginatedResults[ImageRecord]:
) -> PaginatedResults[ImageDTO]:
"""Gets a list of image records by type and category"""
images = ApiDependencies.invoker.services.images_new.get_many(
image_dtos = ApiDependencies.invoker.services.images_new.get_many(
image_type=image_type,
image_category=image_category,
page=page,
per_page=per_page,
)
return images
return image_dtos
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
@ -66,9 +66,6 @@ async def delete_image_record(
ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name
)
except ImageStorageBase.ImageFileDeleteException:
# TODO: log this
pass
except ImageRecordServiceBase.ImageRecordDeleteException:
# TODO: log this
except Exception as e:
# TODO: Does this need any exception handling at all?
pass

View File

@ -1,107 +1,39 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.post(
"/uploads/",
"/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
201: {"description": "The image was uploaded successfully"},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
file: UploadFile,
image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.IMAGE,
) -> ImageRecord:
"""Uploads an image"""
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@ -113,53 +45,33 @@ async def upload_image(
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
try:
image_record = ApiDependencies.invoker.services.images_new.create(
image=img,
image_type=image_type,
image_category=image_category,
)
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
response.status_code = 201
response.headers["Location"] = image_record.image_url
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
return image_record
except Exception as e:
raise HTTPException(status_code=500)
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image_record(
image_type: ImageType = Query(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image record"""
try:
ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name
)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass

View File

@ -15,7 +15,7 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema
from .api.dependencies import ApiDependencies
from .api.routers import image_records, images, sessions, models
from .api.routers import image_files, image_records, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
@ -71,7 +71,7 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(image_files.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")

View File

@ -28,7 +28,7 @@ from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
@ -215,7 +215,7 @@ def invoke_cli():
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -26,8 +26,8 @@ from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Low-level service responsible for storing and retrieving images."""
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
@ -75,19 +75,19 @@ class ImageStorageBase(ABC):
"""Gets the external URI to an image or its thumbnail."""
pass
@abstractmethod
def get_image_location(
self, image_type: ImageType, image_name: str
) -> str:
"""Gets the location of an image."""
pass
# @abstractmethod
# def get_image_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image."""
# pass
@abstractmethod
def get_thumbnail_location(
self, image_type: ImageType, image_name: str
) -> str:
"""Gets the location of an image's thumbnail."""
pass
# @abstractmethod
# def get_thumbnail_location(
# self, image_type: ImageType, image_name: str
# ) -> str:
# """Gets the location of an image's thumbnail."""
# pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
@ -116,7 +116,7 @@ class ImageStorageBase(ABC):
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
@ -206,7 +206,7 @@ class DiskImageStorage(ImageStorageBase):
self.__set_cache(image_path, image)
return image
except Exception as e:
raise ImageStorageBase.ImageFileNotFoundException from e
raise ImageFileStorageBase.ImageFileNotFoundException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
@ -282,7 +282,7 @@ class DiskImageStorage(ImageStorageBase):
created=int(os.path.getctime(image_path)),
)
except Exception as e:
raise ImageStorageBase.ImageFileSaveException from e
raise ImageFileStorageBase.ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
@ -302,7 +302,7 @@ class DiskImageStorage(ImageStorageBase):
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageStorageBase.ImageFileDeleteException from e
raise ImageFileStorageBase.ImageFileDeleteException from e
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -26,7 +26,7 @@ from invokeai.app.services.util.deserialize_image_record import (
from invokeai.app.services.item_storage import PaginatedResults
class ImageRecordServiceBase(ABC):
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception):
@ -85,7 +85,7 @@ class ImageRecordServiceBase(ABC):
pass
class SqliteImageRecordService(ImageRecordServiceBase):
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
@ -277,7 +277,7 @@ class SqliteImageRecordService(ImageRecordServiceBase):
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordDeleteException from e
raise ImageRecordStorageBase.ImageRecordDeleteException from e
finally:
self._lock.release()
@ -324,6 +324,6 @@ class SqliteImageRecordService(ImageRecordServiceBase):
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordNotFoundException from e
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
finally:
self._lock.release()

View File

@ -1,4 +1,4 @@
from typing import Union
from typing import Optional, Union
import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ImageCategory, ImageType
@ -6,11 +6,15 @@ from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.services.image_db import (
ImageRecordServiceBase,
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
)
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.image_storage import ImageStorageBase
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
@ -20,22 +24,22 @@ from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceDependencies:
"""Service dependencies for the ImageManagementService."""
db: ImageRecordServiceBase
storage: ImageStorageBase
records: ImageRecordStorageBase
files: ImageFileStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
def __init__(
self,
image_db_service: ImageRecordServiceBase,
image_storage_service: ImageStorageBase,
image_metadata_service: MetadataServiceBase,
url_service: UrlServiceBase,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
):
self.db = image_db_service
self.storage = image_storage_service
self.metadata = image_metadata_service
self.url = url_service
self.records = image_record_storage
self.files = image_file_storage
self.metadata = metadata
self.urls = url
class ImageService:
@ -45,24 +49,24 @@ class ImageService:
def __init__(
self,
image_db_service: ImageRecordServiceBase,
image_storage_service: ImageStorageBase,
image_metadata_service: MetadataServiceBase,
url_service: UrlServiceBase,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
):
self._services = ImageServiceDependencies(
image_db_service=image_db_service,
image_storage_service=image_storage_service,
image_metadata_service=image_metadata_service,
url_service=url_service,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
)
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Union[str, None],
session_id: Union[str, None],
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Creates an image name."""
uuid_str = str(uuid.uuid4())
@ -77,12 +81,12 @@ class ImageService:
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Union[str, None],
session_id: Union[str, None],
metadata: Union[
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None
],
) -> ImageRecord:
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
image_name = self._create_image_name(
image_type=image_type,
@ -95,14 +99,14 @@ class ImageService:
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.storage.save(
self._services.files.save(
image_type=image_type,
image_name=image_name,
image=image,
metadata=metadata,
)
self._services.db.save(
self._services.records.save(
image_name=image_name,
image_type=image_type,
image_category=image_category,
@ -112,15 +116,10 @@ class ImageService:
created_at=timestamp,
)
image_url = self._services.url.get_image_url(
image_type=image_type, image_name=image_name
)
image_url = self._services.urls.get_image_url(image_type, image_name)
thumbnail_url = self._services.urls.get_thumbnail_url(image_type, image_name)
thumbnail_url = self._services.url.get_thumbnail_url(
image_type=image_type, image_name=image_name
)
return ImageRecord(
return ImageDTO(
image_name=image_name,
image_type=image_type,
image_category=image_category,
@ -131,32 +130,42 @@ class ImageService:
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordServiceBase.ImageRecordSaveException:
except ImageRecordStorageBase.ImageRecordSaveException:
# TODO: log this
raise
except ImageStorageBase.ImageFileSaveException:
except ImageFileStorageBase.ImageFileSaveException:
# TODO: log this
raise
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
try:
pil_image = self._services.storage.get(
image_type=image_type, image_name=image_name
)
return pil_image
except ImageStorageBase.ImageFileNotFoundException:
return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException:
# TODO: log this
raise
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
try:
image_record = self._services.db.get(
image_type=image_type, image_name=image_name
return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException:
# TODO: log this
raise
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
try:
image_record = self._services.records.get(image_type, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_type, image_name),
self._services.urls.get_thumbnail_url(image_type, image_name),
)
return image_record
except ImageRecordServiceBase.ImageRecordNotFoundException:
return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException:
# TODO: log this
raise
@ -164,12 +173,12 @@ class ImageService:
"""Deletes an image."""
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.storage.delete(image_type=image_type, image_name=image_name)
self._services.db.delete(image_type=image_type, image_name=image_name)
except ImageRecordServiceBase.ImageRecordDeleteException:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
# TODO: log this
raise
except ImageStorageBase.ImageFileDeleteException:
except ImageFileStorageBase.ImageFileDeleteException:
# TODO: log this
raise
@ -179,26 +188,34 @@ class ImageService:
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
"""Gets a paginated list of image records."""
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
try:
results = self._services.db.get_many(
image_type=image_type,
image_category=image_category,
page=page,
per_page=per_page,
results = self._services.records.get_many(
image_type,
image_category,
page,
per_page,
)
for r in results.items:
r.image_url = self._services.url.get_image_url(
image_type=image_type, image_name=r.image_name
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(image_type, r.image_name),
self._services.urls.get_thumbnail_url(image_type, r.image_name),
),
results.items,
)
)
r.thumbnail_url = self._services.url.get_thumbnail_url(
image_type=image_type, image_name=r.image_name
)
return results
return PaginatedResults[ImageDTO](
items=image_dtos,
page=results.page,
pages=results.pages,
per_page=results.per_page,
total=results.total,
)
except Exception as e:
raise e

View File

@ -1,8 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from types import ModuleType
from invokeai.app.services.image_db import (
ImageRecordServiceBase,
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
)
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import MetadataServiceBase
@ -11,7 +11,7 @@ from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
@ -23,13 +23,12 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
images: ImageFileStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
images_db: ImageRecordServiceBase
urls: UrlServiceBase
images_new: ImageService
@ -44,10 +43,9 @@ class InvocationServices:
events: EventServiceBase,
logger: ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
images: ImageFileStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
images_db: ImageRecordServiceBase,
images_new: ImageService,
urls: UrlServiceBase,
graph_library: ItemStorageABC["LibraryGraph"],
@ -63,7 +61,6 @@ class InvocationServices:
self.images = images
self.metadata = metadata
self.queue = queue
self.images_db = images_db
self.images_new = images_new
self.urls = urls
self.graph_library = graph_library

View File

@ -48,7 +48,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
return latent
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:

View File

@ -75,10 +75,10 @@ class MetadataServiceBase(ABC):
"""Builds an InvokeAIMetadata object"""
pass
@abstractmethod
def create_metadata(self, session_id: str, node_id: str) -> dict:
"""Creates metadata for a result"""
pass
# @abstractmethod
# def create_metadata(self, session_id: str, node_id: str) -> dict:
# """Creates metadata for a result"""
# pass
class PngMetadataService(MetadataServiceBase):

View File

@ -1,12 +1,11 @@
import datetime
from typing import Literal, Optional, Union
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata,
)
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.resources import ResourceType
class ImageRecord(BaseModel):
@ -23,7 +22,27 @@ class ImageRecord(BaseModel):
metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = Field(default=None, description="The image's metadata.")
image_url: Optional[str] = Field(default=None, description="The URL of the image.")
thumbnail_url: Optional[str] = Field(
default=None, description="The thumbnail URL of the image."
class ImageDTO(ImageRecord):
"""Deserialized image record with URLs."""
image_url: str = Field(description="The URL of the image.")
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
image_name=image_record.image_name,
image_type=image_record.image_type,
image_category=image_record.image_category,
created_at=image_record.created_at,
session_id=image_record.session_id,
node_id=image_record.node_id,
metadata=image_record.metadata,
image_url=image_url,
thumbnail_url=thumbnail_url,
)