feat(db, api): update get_board_for_image & service dependencies

- previously was `get_boards_for_image`, returning a list of `BoardDTO`, now returns a single `board_id`
This commit is contained in:
psychedelicious 2023-06-16 15:52:32 +10:00
parent 70cc037a9c
commit d604d986f9
5 changed files with 69 additions and 132 deletions

View File

@ -12,7 +12,7 @@ from invokeai.app.services.board_images import (
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
@ -106,13 +106,16 @@ class ApiDependencies:
) )
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, services=ImageServiceDependencies(
image_file_storage=image_file_storage, board_image_record_storage=board_image_record_storage,
metadata=metadata, image_record_storage=image_record_storage,
url=urls, image_file_storage=image_file_storage,
logger=logger, metadata=metadata,
names=names, url=urls,
graph_execution_manager=graph_execution_manager, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
) )
services = InvocationServices( services = InvocationServices(

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import sqlite3 import sqlite3
import threading import threading
from typing import cast from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.image_record_storage import OffsetPaginatedResults
@ -41,11 +41,11 @@ class BoardImageRecordStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_boards_for_image( def get_board_for_image(
self, self,
board_id: str, image_name: str,
) -> OffsetPaginatedResults[BoardRecord]: ) -> Union[str, None]:
"""Gets boards for an image.""" """Gets an image's board id, if it has one."""
pass pass
@abstractmethod @abstractmethod
@ -134,7 +134,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str, board_id: str,
image_name: str, image_name: str,
) -> None: ) -> None:
"""Adds an image to a board."""
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -156,7 +155,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str, board_id: str,
image_name: str, image_name: str,
) -> None: ) -> None:
"""Removes an image from a board."""
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -179,7 +177,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0, offset: int = 0,
limit: int = 10, limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]: ) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -211,46 +208,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count items=images, offset=offset, limit=limit, total=count
) )
def get_boards_for_image( def get_board_for_image(
self, self,
board_id: str, image_name: str,
offset: int = 0, ) -> Union[str, None]:
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets boards for an image."""
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT boards.* SELECT board_id
FROM board_images FROM board_images
INNER JOIN boards ON board_images.board_id = boards.board_id WHERE image_name = ?;
WHERE board_images.image_name = ?
ORDER BY board_images.updated_at DESC;
""", """,
(board_id,), (image_name,),
) )
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = self._cursor.fetchone()
boards = list(map(lambda r: BoardRecord(**r), result)) if result is None:
return None
self._cursor.execute( return cast(str, result[0])
"""--sql
SELECT COUNT(*) FROM boards WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise e raise e
finally: finally:
self._lock.release() self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
)
def get_image_count_for_board(self, board_id: str) -> int: def get_image_count_for_board(self, board_id: str) -> int:
"""Gets the number of images for a board."""
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from typing import Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import ( from invokeai.app.services.board_record_storage import (
BoardRecord, BoardRecord,
@ -45,11 +46,11 @@ class BoardImagesServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_boards_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> OffsetPaginatedResults[BoardDTO]: ) -> Union[str, None]:
"""Gets boards for an image.""" """Gets an image's board id, if it has one."""
pass pass
@ -110,6 +111,7 @@ class BoardImagesService(BoardImagesServiceABC):
r, r,
self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True), self._services.urls.get_image_url(r.image_name, True),
board_id,
), ),
image_records.items, image_records.items,
) )
@ -121,38 +123,12 @@ class BoardImagesService(BoardImagesServiceABC):
total=image_records.total, total=image_records.total,
) )
def get_boards_for_image( def get_board_for_image(
self, self,
image_name: str, image_name: str,
) -> OffsetPaginatedResults[BoardDTO]: ) -> Union[str, None]:
board_records = self._services.board_image_records.get_boards_for_image( board_id = self._services.board_image_records.get_board_for_image(image_name)
image_name return board_id
)
board_dtos = []
for r in board_records.items:
cover_image_url = (
self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(
board_record_to_dto(
r,
cover_image_url,
image_count,
)
)
return OffsetPaginatedResults[BoardDTO](
items=board_dtos,
offset=board_records.offset,
limit=board_records.limit,
total=board_records.total,
)
def board_record_to_dto( def board_record_to_dto(

View File

@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException, InvalidOriginException,
) )
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException, ImageRecordDeleteException,
ImageRecordNotFoundException, ImageRecordNotFoundException,
@ -114,8 +115,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies: class ImageServiceDependencies:
"""Service dependencies for the ImageService.""" """Service dependencies for the ImageService."""
records: ImageRecordStorageBase image_records: ImageRecordStorageBase
files: ImageFileStorageBase image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
@ -126,14 +128,16 @@ class ImageServiceDependencies:
self, self,
image_record_storage: ImageRecordStorageBase, image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase, image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase, names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.image_records = image_record_storage
self.files = image_file_storage self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
@ -144,25 +148,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC): class ImageService(ImageServiceABC):
_services: ImageServiceDependencies _services: ImageServiceDependencies
def __init__( def __init__(self, services: ImageServiceDependencies):
self, self._services = services
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def create( def create(
self, self,
@ -187,7 +174,7 @@ class ImageService(ImageServiceABC):
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save( self._services.image_records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -202,35 +189,15 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
) )
self._services.files.save( self._services.image_files.save(
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_name) image_dto = self.get_dto(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
return ImageDTO( return image_dto
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to save image record") self._services.logger.error("Failed to save image record")
raise raise
@ -247,7 +214,7 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, changes) self._services.image_records.update(image_name, changes)
return self.get_dto(image_name) return self.get_dto(image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
@ -258,7 +225,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_name) return self._services.image_files.get(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -268,7 +235,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_name) return self._services.image_records.get(image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -278,12 +245,13 @@ class ImageService(ImageServiceABC):
def get_dto(self, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_name) image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_name), self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True), self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
) )
return image_dto return image_dto
@ -296,14 +264,14 @@ class ImageService(ImageServiceABC):
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.files.get_path(image_name, thumbnail) return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
def validate_path(self, path: str) -> bool: def validate_path(self, path: str) -> bool:
try: try:
return self._services.files.validate_path(path) return self._services.image_files.validate_path(path)
except Exception as e: except Exception as e:
self._services.logger.error("Problem validating image path") self._services.logger.error("Problem validating image path")
raise e raise e
@ -324,7 +292,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.image_records.get_many(
offset, offset,
limit, limit,
image_origin, image_origin,
@ -338,6 +306,9 @@ class ImageService(ImageServiceABC):
r, r,
self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True), self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
), ),
results.items, results.items,
) )
@ -355,8 +326,8 @@ class ImageService(ImageServiceABC):
def delete(self, image_name: str): def delete(self, image_name: str):
try: try:
self._services.files.delete(image_name) self._services.image_files.delete(image_name)
self._services.records.delete(image_name) self._services.image_records.delete(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise

View File

@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs.""" """Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass pass
def image_record_to_dto( def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(
**image_record.dict(), **image_record.dict(),
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
board_id=board_id,
) )