feat(nodes): wip image storage implementation

This commit is contained in:
psychedelicious 2023-05-21 20:05:33 +10:00 committed by Kent Keirsey
parent d4aa79acd7
commit 1b75d899ae
13 changed files with 383 additions and 268 deletions

View File

@ -2,9 +2,8 @@
import os import os
from types import ModuleType from types import ModuleType
from invokeai.app.services.database.images.sqlite_images_db_service import ( from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
SqliteImageDb, from invokeai.app.services.images import ImageService
)
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -14,7 +13,7 @@ from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsSto
from ..services.model_manager_initializer import get_model_manager from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker from ..services.invoker import Invoker
@ -63,7 +62,9 @@ class ApiDependencies:
urls = LocalUrlService() urls = LocalUrlService()
images = DiskImageStorage(f"{output_folder}/images", metadata_service=metadata) image_file_storage = DiskImageFileStorage(
f"{output_folder}/images", metadata_service=metadata
)
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")
@ -72,7 +73,14 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
) )
images_db = SqliteImageDb(filename=db_location) image_record_storage = SqliteImageRecordStorage(db_location)
images_new = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
)
# register event handler to update the `results` table when a graph execution state is inserted or updated # register event handler to update the `results` table when a graph execution state is inserted or updated
# graph_execution_manager.on_changed(results.handle_graph_execution_state_change) # graph_execution_manager.on_changed(results.handle_graph_execution_state_change)
@ -82,8 +90,8 @@ class ApiDependencies:
events=events, events=events,
latents=latents, latents=latents,
images=images, images=images,
images_new=images_new,
metadata=metadata, metadata=metadata,
images_db=images_db,
urls=urls, urls=urls,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -0,0 +1,165 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@ -4,28 +4,28 @@ from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ImageType,
) )
from invokeai.app.services.image_db import ImageRecordServiceBase
from invokeai.app.services.image_storage import ImageStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
image_records_router = APIRouter(prefix="/v1/images", tags=["images", "records"]) image_records_router = APIRouter(
prefix="/v1/images/records", tags=["images", "records"]
)
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record") @image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
async def get_image_record( async def get_image_record(
image_type: ImageType = Path(description="The type of the image record to get"), image_type: ImageType = Path(description="The type of the image record to get"),
image_name: str = Path(description="The id of the image record to get"), image_name: str = Path(description="The id of the image record to get"),
) -> ImageRecord: ) -> ImageDTO:
"""Gets an image record by id""" """Gets an image record by id"""
try: try:
return ApiDependencies.invoker.services.images_new.get_record( return ApiDependencies.invoker.services.images_new.get_dto(
image_type=image_type, image_name=image_name image_type=image_type, image_name=image_name
) )
except ImageRecordServiceBase.ImageRecordNotFoundException: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -42,17 +42,17 @@ async def list_image_records(
per_page: int = Query( per_page: int = Query(
default=10, description="The number of image records per page" default=10, description="The number of image records per page"
), ),
) -> PaginatedResults[ImageRecord]: ) -> PaginatedResults[ImageDTO]:
"""Gets a list of image records by type and category""" """Gets a list of image records by type and category"""
images = ApiDependencies.invoker.services.images_new.get_many( image_dtos = ApiDependencies.invoker.services.images_new.get_many(
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
page=page, page=page,
per_page=per_page, per_page=per_page,
) )
return images return image_dtos
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
@ -66,9 +66,6 @@ async def delete_image_record(
ApiDependencies.invoker.services.images_new.delete( ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name image_type=image_type, image_name=image_name
) )
except ImageStorageBase.ImageFileDeleteException: except Exception as e:
# TODO: log this # TODO: Does this need any exception handling at all?
pass
except ImageRecordServiceBase.ImageRecordDeleteException:
# TODO: log this
pass pass

View File

@ -1,107 +1,39 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid import uuid
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from invokeai.app.api.models.images import ( from invokeai.app.models.image import (
ImageResponse, ImageCategory,
ImageResponseMetadata, ImageType,
) )
from invokeai.app.models.image import ImageType from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"]) images_router = APIRouter(prefix="/v1/images", tags=["images"])
# @images_router.get("/{image_type}/{image_name}", operation_id="get_image")
# async def get_image(
# image_type: ImageType = Path(description="The type of image to get"),
# image_name: str = Path(description="The name of the image to get"),
# ) -> FileResponse:
# """Gets an image"""
# path = ApiDependencies.invoker.services.images.get_path(
# image_type=image_type, image_name=image_name
# )
# if ApiDependencies.invoker.services.images.validate_path(path):
# return FileResponse(path)
# else:
# raise HTTPException(status_code=404)
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of the image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{image_type}/thumbnails/{thumbnail_id}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of the thumbnail to get"),
thumbnail_id: str = Path(description="The id of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=thumbnail_id, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post( @images_router.post(
"/uploads/", "/",
operation_id="upload_image", operation_id="upload_image",
responses={ responses={
201: { 201: {"description": "The image was uploaded successfully"},
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"}, 415: {"description": "Image upload failed"},
}, },
status_code=201, status_code=201,
) )
async def upload_image( async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response file: UploadFile,
) -> ImageResponse: image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.IMAGE,
) -> ImageRecord:
"""Uploads an image"""
if not file.content_type.startswith("image"): if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image") raise HTTPException(status_code=415, detail="Not an image")
@ -113,53 +45,33 @@ async def upload_image(
# Error opening the image # Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image") raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png" try:
image_record = ApiDependencies.invoker.services.images_new.create(
saved_image = ApiDependencies.invoker.services.images.save( image=img,
image_type, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type, image_type=image_type,
image_name=saved_image.image_name, image_category=image_category,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
) )
response.status_code = 201 response.status_code = 201
response.headers["Location"] = image_url response.headers["Location"] = image_record.image_url
return res return image_record
except Exception as e:
raise HTTPException(status_code=500)
@images_router.get(
"/", @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
operation_id="list_images", async def delete_image_record(
responses={200: {"model": PaginatedResults[ImageResponse]}}, image_type: ImageType = Query(description="The type of image to delete"),
) image_name: str = Path(description="The name of the image to delete"),
async def list_images( ) -> None:
image_type: ImageType = Query( """Deletes an image record"""
default=ImageType.RESULT, description="The type of images to get"
), try:
page: int = Query(default=0, description="The page of images to get"), ApiDependencies.invoker.services.images_new.delete(
per_page: int = Query(default=10, description="The number of images per page"), image_type=image_type, image_name=image_name
) -> PaginatedResults[ImageResponse]: )
"""Gets a list of images""" except Exception as e:
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page) # TODO: Does this need any exception handling at all?
return result pass

View File

@ -15,7 +15,7 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.schema import schema from pydantic.schema import schema
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import image_records, images, sessions, models from .api.routers import image_files, image_records, sessions, models
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
@ -71,7 +71,7 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api") app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(image_files.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api") app.include_router(models.models_router, prefix="/api")

View File

@ -28,7 +28,7 @@ from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
@ -215,7 +215,7 @@ def invoke_cli():
model_manager=model_manager, model_manager=model_manager,
events=events, events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata), images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata, metadata=metadata,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -26,8 +26,8 @@ from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC): class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving images.""" """Low-level service responsible for storing and retrieving image files."""
class ImageFileNotFoundException(Exception): class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage.""" """Raised when an image file is not found in storage."""
@ -75,19 +75,19 @@ class ImageStorageBase(ABC):
"""Gets the external URI to an image or its thumbnail.""" """Gets the external URI to an image or its thumbnail."""
pass pass
@abstractmethod # @abstractmethod
def get_image_location( # def get_image_location(
self, image_type: ImageType, image_name: str # self, image_type: ImageType, image_name: str
) -> str: # ) -> str:
"""Gets the location of an image.""" # """Gets the location of an image."""
pass # pass
@abstractmethod # @abstractmethod
def get_thumbnail_location( # def get_thumbnail_location(
self, image_type: ImageType, image_name: str # self, image_type: ImageType, image_name: str
) -> str: # ) -> str:
"""Gets the location of an image's thumbnail.""" # """Gets the location of an image's thumbnail."""
pass # pass
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod @abstractmethod
@ -116,7 +116,7 @@ class ImageStorageBase(ABC):
return f"{context_id}_{node_id}_{str(get_timestamp())}.png" return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase): class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: str __output_folder: str
@ -206,7 +206,7 @@ class DiskImageStorage(ImageStorageBase):
self.__set_cache(image_path, image) self.__set_cache(image_path, image)
return image return image
except Exception as e: except Exception as e:
raise ImageStorageBase.ImageFileNotFoundException from e raise ImageFileStorageBase.ImageFileNotFoundException from e
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(
@ -282,7 +282,7 @@ class DiskImageStorage(ImageStorageBase):
created=int(os.path.getctime(image_path)), created=int(os.path.getctime(image_path)),
) )
except Exception as e: except Exception as e:
raise ImageStorageBase.ImageFileSaveException from e raise ImageFileStorageBase.ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_type: ImageType, image_name: str) -> None:
try: try:
@ -302,7 +302,7 @@ class DiskImageStorage(ImageStorageBase):
if thumbnail_path in self.__cache: if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path] del self.__cache[thumbnail_path]
except Exception as e: except Exception as e:
raise ImageStorageBase.ImageFileDeleteException from e raise ImageFileStorageBase.ImageFileDeleteException from e
def __get_cache(self, image_name: str) -> Image | None: def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name] return None if image_name not in self.__cache else self.__cache[image_name]

View File

@ -26,7 +26,7 @@ from invokeai.app.services.util.deserialize_image_record import (
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
class ImageRecordServiceBase(ABC): class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store.""" """Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception): class ImageRecordNotFoundException(Exception):
@ -85,7 +85,7 @@ class ImageRecordServiceBase(ABC):
pass pass
class SqliteImageRecordService(ImageRecordServiceBase): class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str _filename: str
_conn: sqlite3.Connection _conn: sqlite3.Connection
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
@ -277,7 +277,7 @@ class SqliteImageRecordService(ImageRecordServiceBase):
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordDeleteException from e raise ImageRecordStorageBase.ImageRecordDeleteException from e
finally: finally:
self._lock.release() self._lock.release()
@ -324,6 +324,6 @@ class SqliteImageRecordService(ImageRecordServiceBase):
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise ImageRecordServiceBase.ImageRecordNotFoundException from e raise ImageRecordStorageBase.ImageRecordNotFoundException from e
finally: finally:
self._lock.release() self._lock.release()

View File

@ -1,4 +1,4 @@
from typing import Union from typing import Optional, Union
import uuid import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
@ -6,11 +6,15 @@ from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata, GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata,
) )
from invokeai.app.services.image_db import ( from invokeai.app.services.image_record_storage import (
ImageRecordServiceBase, ImageRecordStorageBase,
) )
from invokeai.app.services.models.image_record import ImageRecord from invokeai.app.services.models.image_record import (
from invokeai.app.services.image_storage import ImageStorageBase ImageRecord,
ImageDTO,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
@ -20,22 +24,22 @@ from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceDependencies: class ImageServiceDependencies:
"""Service dependencies for the ImageManagementService.""" """Service dependencies for the ImageManagementService."""
db: ImageRecordServiceBase records: ImageRecordStorageBase
storage: ImageStorageBase files: ImageFileStorageBase
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
def __init__( def __init__(
self, self,
image_db_service: ImageRecordServiceBase, image_record_storage: ImageRecordStorageBase,
image_storage_service: ImageStorageBase, image_file_storage: ImageFileStorageBase,
image_metadata_service: MetadataServiceBase, metadata: MetadataServiceBase,
url_service: UrlServiceBase, url: UrlServiceBase,
): ):
self.db = image_db_service self.records = image_record_storage
self.storage = image_storage_service self.files = image_file_storage
self.metadata = image_metadata_service self.metadata = metadata
self.url = url_service self.urls = url
class ImageService: class ImageService:
@ -45,24 +49,24 @@ class ImageService:
def __init__( def __init__(
self, self,
image_db_service: ImageRecordServiceBase, image_record_storage: ImageRecordStorageBase,
image_storage_service: ImageStorageBase, image_file_storage: ImageFileStorageBase,
image_metadata_service: MetadataServiceBase, metadata: MetadataServiceBase,
url_service: UrlServiceBase, url: UrlServiceBase,
): ):
self._services = ImageServiceDependencies( self._services = ImageServiceDependencies(
image_db_service=image_db_service, image_record_storage=image_record_storage,
image_storage_service=image_storage_service, image_file_storage=image_file_storage,
image_metadata_service=image_metadata_service, metadata=metadata,
url_service=url_service, url=url,
) )
def _create_image_name( def _create_image_name(
self, self,
image_type: ImageType, image_type: ImageType,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Union[str, None], node_id: Optional[str] = None,
session_id: Union[str, None], session_id: Optional[str] = None,
) -> str: ) -> str:
"""Creates an image name.""" """Creates an image name."""
uuid_str = str(uuid.uuid4()) uuid_str = str(uuid.uuid4())
@ -77,12 +81,12 @@ class ImageService:
image: PILImageType, image: PILImageType,
image_type: ImageType, image_type: ImageType,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Union[str, None], node_id: Optional[str] = None,
session_id: Union[str, None], session_id: Optional[str] = None,
metadata: Union[ metadata: Optional[
GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata, None Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
], ] = None,
) -> ImageRecord: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
image_name = self._create_image_name( image_name = self._create_image_name(
image_type=image_type, image_type=image_type,
@ -95,14 +99,14 @@ class ImageService:
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.storage.save( self._services.files.save(
image_type=image_type, image_type=image_type,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
self._services.db.save( self._services.records.save(
image_name=image_name, image_name=image_name,
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
@ -112,15 +116,10 @@ class ImageService:
created_at=timestamp, created_at=timestamp,
) )
image_url = self._services.url.get_image_url( image_url = self._services.urls.get_image_url(image_type, image_name)
image_type=image_type, image_name=image_name thumbnail_url = self._services.urls.get_thumbnail_url(image_type, image_name)
)
thumbnail_url = self._services.url.get_thumbnail_url( return ImageDTO(
image_type=image_type, image_name=image_name
)
return ImageRecord(
image_name=image_name, image_name=image_name,
image_type=image_type, image_type=image_type,
image_category=image_category, image_category=image_category,
@ -131,32 +130,42 @@ class ImageService:
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
) )
except ImageRecordServiceBase.ImageRecordSaveException: except ImageRecordStorageBase.ImageRecordSaveException:
# TODO: log this # TODO: log this
raise raise
except ImageStorageBase.ImageFileSaveException: except ImageFileStorageBase.ImageFileSaveException:
# TODO: log this # TODO: log this
raise raise
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
try: try:
pil_image = self._services.storage.get( return self._services.files.get(image_type, image_name)
image_type=image_type, image_name=image_name except ImageFileStorageBase.ImageFileNotFoundException:
)
return pil_image
except ImageStorageBase.ImageFileNotFoundException:
# TODO: log this # TODO: log this
raise raise
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
try: try:
image_record = self._services.db.get( return self._services.records.get(image_type, image_name)
image_type=image_type, image_name=image_name except ImageRecordStorageBase.ImageRecordNotFoundException:
# TODO: log this
raise
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
try:
image_record = self._services.records.get(image_type, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_type, image_name),
self._services.urls.get_thumbnail_url(image_type, image_name),
) )
return image_record
except ImageRecordServiceBase.ImageRecordNotFoundException: return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException:
# TODO: log this # TODO: log this
raise raise
@ -164,12 +173,12 @@ class ImageService:
"""Deletes an image.""" """Deletes an image."""
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
try: try:
self._services.storage.delete(image_type=image_type, image_name=image_name) self._services.files.delete(image_type, image_name)
self._services.db.delete(image_type=image_type, image_name=image_name) self._services.records.delete(image_type, image_name)
except ImageRecordServiceBase.ImageRecordDeleteException: except ImageRecordStorageBase.ImageRecordDeleteException:
# TODO: log this # TODO: log this
raise raise
except ImageStorageBase.ImageFileDeleteException: except ImageFileStorageBase.ImageFileDeleteException:
# TODO: log this # TODO: log this
raise raise
@ -179,26 +188,34 @@ class ImageService:
image_category: ImageCategory, image_category: ImageCategory,
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
) -> PaginatedResults[ImageRecord]: ) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image records.""" """Gets a paginated list of image DTOs."""
try: try:
results = self._services.db.get_many( results = self._services.records.get_many(
image_type=image_type, image_type,
image_category=image_category, image_category,
page=page, page,
per_page=per_page, per_page,
) )
for r in results.items: image_dtos = list(
r.image_url = self._services.url.get_image_url( map(
image_type=image_type, image_name=r.image_name lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(image_type, r.image_name),
self._services.urls.get_thumbnail_url(image_type, r.image_name),
),
results.items,
)
) )
r.thumbnail_url = self._services.url.get_thumbnail_url( return PaginatedResults[ImageDTO](
image_type=image_type, image_name=r.image_name items=image_dtos,
page=results.page,
pages=results.pages,
per_page=results.per_page,
total=results.total,
) )
return results
except Exception as e: except Exception as e:
raise e raise e

View File

@ -1,8 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from types import ModuleType from types import ModuleType
from invokeai.app.services.image_db import ( from invokeai.app.services.image_record_storage import (
ImageRecordServiceBase, ImageRecordStorageBase,
) )
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
@ -11,7 +11,7 @@ from invokeai.backend import ModelManager
from .events import EventServiceBase from .events import EventServiceBase
from .latent_storage import LatentsStorageBase from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
@ -23,13 +23,12 @@ class InvocationServices:
events: EventServiceBase events: EventServiceBase
latents: LatentsStorageBase latents: LatentsStorageBase
images: ImageStorageBase images: ImageFileStorageBase
metadata: MetadataServiceBase metadata: MetadataServiceBase
queue: InvocationQueueABC queue: InvocationQueueABC
model_manager: ModelManager model_manager: ModelManager
restoration: RestorationServices restoration: RestorationServices
configuration: InvokeAISettings configuration: InvokeAISettings
images_db: ImageRecordServiceBase
urls: UrlServiceBase urls: UrlServiceBase
images_new: ImageService images_new: ImageService
@ -44,10 +43,9 @@ class InvocationServices:
events: EventServiceBase, events: EventServiceBase,
logger: ModuleType, logger: ModuleType,
latents: LatentsStorageBase, latents: LatentsStorageBase,
images: ImageStorageBase, images: ImageFileStorageBase,
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
queue: InvocationQueueABC, queue: InvocationQueueABC,
images_db: ImageRecordServiceBase,
images_new: ImageService, images_new: ImageService,
urls: UrlServiceBase, urls: UrlServiceBase,
graph_library: ItemStorageABC["LibraryGraph"], graph_library: ItemStorageABC["LibraryGraph"],
@ -63,7 +61,6 @@ class InvocationServices:
self.images = images self.images = images
self.metadata = metadata self.metadata = metadata
self.queue = queue self.queue = queue
self.images_db = images_db
self.images_new = images_new self.images_new = images_new
self.urls = urls self.urls = urls
self.graph_library = graph_library self.graph_library = graph_library

View File

@ -48,7 +48,7 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
return latent return latent
def save(self, name: str, data: torch.Tensor) -> None: def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data) self.__underlying_storage.save(name, data)
self.__set_cache(name, data) self.__set_cache(name, data)
def delete(self, name: str) -> None: def delete(self, name: str) -> None:

View File

@ -75,10 +75,10 @@ class MetadataServiceBase(ABC):
"""Builds an InvokeAIMetadata object""" """Builds an InvokeAIMetadata object"""
pass pass
@abstractmethod # @abstractmethod
def create_metadata(self, session_id: str, node_id: str) -> dict: # def create_metadata(self, session_id: str, node_id: str) -> dict:
"""Creates metadata for a result""" # """Creates metadata for a result"""
pass # pass
class PngMetadataService(MetadataServiceBase): class PngMetadataService(MetadataServiceBase):

View File

@ -1,12 +1,11 @@
import datetime import datetime
from typing import Literal, Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.metadata import ( from invokeai.app.models.metadata import (
GeneratedImageOrLatentsMetadata, GeneratedImageOrLatentsMetadata,
UploadedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata,
) )
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.resources import ResourceType
class ImageRecord(BaseModel): class ImageRecord(BaseModel):
@ -23,7 +22,27 @@ class ImageRecord(BaseModel):
metadata: Optional[ metadata: Optional[
Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata] Union[GeneratedImageOrLatentsMetadata, UploadedImageOrLatentsMetadata]
] = Field(default=None, description="The image's metadata.") ] = Field(default=None, description="The image's metadata.")
image_url: Optional[str] = Field(default=None, description="The URL of the image.")
thumbnail_url: Optional[str] = Field(
default=None, description="The thumbnail URL of the image." class ImageDTO(ImageRecord):
"""Deserialized image record with URLs."""
image_url: str = Field(description="The URL of the image.")
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
image_name=image_record.image_name,
image_type=image_record.image_type,
image_category=image_record.image_category,
created_at=image_record.created_at,
session_id=image_record.session_id,
node_id=image_record.node_id,
metadata=image_record.metadata,
image_url=image_url,
thumbnail_url=thumbnail_url,
) )