feat(db, api): update get_board_for_image & service dependencies

- previously was `get_boards_for_image`, returning a list of `BoardDTO`, now returns a single `board_id`
This commit is contained in:
psychedelicious 2023-06-16 15:52:32 +10:00
parent 70cc037a9c
commit d604d986f9
5 changed files with 69 additions and 132 deletions

View File

@ -12,7 +12,7 @@ from invokeai.app.services.board_images import (
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
@ -106,13 +106,16 @@ class ApiDependencies:
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import cast
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
@ -41,11 +41,11 @@ class BoardImageRecordStorageBase(ABC):
pass
@abstractmethod
def get_boards_for_image(
def get_board_for_image(
self,
board_id: str,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets boards for an image."""
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@abstractmethod
@ -134,7 +134,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
try:
self._lock.acquire()
self._cursor.execute(
@ -156,7 +155,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
try:
self._lock.acquire()
self._cursor.execute(
@ -179,7 +177,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
try:
self._lock.acquire()
self._cursor.execute(
@ -211,46 +208,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
items=images, offset=offset, limit=limit, total=count
)
def get_boards_for_image(
def get_board_for_image(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets boards for an image."""
image_name: str,
) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT boards.*
SELECT board_id
FROM board_images
INNER JOIN boards ON board_images.board_id = boards.board_id
WHERE board_images.image_name = ?
ORDER BY board_images.updated_at DESC;
WHERE image_name = ?;
""",
(board_id,),
(image_name,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: BoardRecord(**r), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM boards WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
)
def get_image_count_for_board(self, board_id: str) -> int:
"""Gets the number of images for a board."""
try:
self._lock.acquire()
self._cursor.execute(

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
@ -45,11 +46,11 @@ class BoardImagesServiceABC(ABC):
pass
@abstractmethod
def get_boards_for_image(
def get_board_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets boards for an image."""
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@ -110,6 +111,7 @@ class BoardImagesService(BoardImagesServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
board_id,
),
image_records.items,
)
@ -121,38 +123,12 @@ class BoardImagesService(BoardImagesServiceABC):
total=image_records.total,
)
def get_boards_for_image(
def get_board_for_image(
self,
image_name: str,
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_image_records.get_boards_for_image(
image_name
)
board_dtos = []
for r in board_records.items:
cover_image_url = (
self._services.urls.get_image_url(r.cover_image_name, True)
if r.cover_image_name
else None
)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(
board_record_to_dto(
r,
cover_image_url,
image_count,
)
)
return OffsetPaginatedResults[BoardDTO](
items=board_dtos,
offset=board_records.offset,
limit=board_records.limit,
total=board_records.total,
)
) -> Union[str, None]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(

View File

@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
@ -114,8 +115,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
@ -126,14 +128,16 @@ class ImageServiceDependencies:
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url
self.logger = logger
@ -144,25 +148,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def __init__(self, services: ImageServiceDependencies):
self._services = services
def create(
self,
@ -187,7 +174,7 @@ class ImageService(ImageServiceABC):
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
self._services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
@ -202,35 +189,15 @@ class ImageService(ImageServiceABC):
metadata=metadata,
)
self._services.files.save(
self._services.image_files.save(
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
image_dto = self.get_dto(image_name)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
@ -247,7 +214,7 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, changes)
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
@ -258,7 +225,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_name)
return self._services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@ -268,7 +235,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_name)
return self._services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@ -278,12 +245,13 @@ class ImageService(ImageServiceABC):
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_name)
image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
)
return image_dto
@ -296,14 +264,14 @@ class ImageService(ImageServiceABC):
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_name, thumbnail)
return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
return self._services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
@ -324,7 +292,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
results = self._services.image_records.get_many(
offset,
limit,
image_origin,
@ -338,6 +306,9 @@ class ImageService(ImageServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
),
results.items,
)
@ -355,8 +326,8 @@ class ImageService(ImageServiceABC):
def delete(self, image_name: str):
try:
self._services.files.delete(image_name)
self._services.records.delete(image_name)
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs."""
"""Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
)