[WIP] board list endpoint w cover photos

This commit is contained in:
maryhipp 2023-06-13 14:08:04 -07:00 committed by psychedelicious
parent 4bfaae6617
commit 3833304f57
3 changed files with 156 additions and 0 deletions

View File

@ -1,5 +1,7 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.boards import BoardRecord, BoardRecordChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from ..dependencies import ApiDependencies
@ -38,3 +40,52 @@ async def delete_board(
pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=OffsetPaginatedResults[BoardRecord],
)
async def list_boards(
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets a list of boards"""
results = ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
boards = list(
map(
lambda r: board_record_to_dto(
r,
generate_cover_photo_url(r.id)
),
results.boards,
)
)
return boards
class BoardDTO(BaseModel):
"""A DTO for an image"""
id: str
name: str
cover_image_url: str
def board_record_to_dto(
board_record: BoardRecord, cover_image_url: str
) -> BoardDTO:
"""Converts an image record to an image DTO."""
return BoardDTO(
**board_record.dict(),
cover_image_url=cover_image_url,
)
def generate_cover_photo_url(board_id: str) -> str | None:
cover_photo = ApiDependencies.invoker.services.images._services.records.get_board_cover_photo(board_id)
if cover_photo is not None:
url = ApiDependencies.invoker.services.images._services.urls.get_image_url(cover_photo.image_origin, cover_photo.image_name)
return url

View File

@ -26,6 +26,23 @@ class BoardRecord(BaseModel):
description="The updated timestamp of the board."
)
class BoardRecordInList(BaseModel):
"""Deserialized board record in a list."""
id: str = Field(description="The unique ID of the board.")
name: str = Field(description="The name of the board.")
most_recent_image_url: Optional[str] = Field(
description="The URL of the most recent image in the board."
)
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
class BoardRecordChanges(BaseModel, extra=Extra.forbid):
name: Optional[str] = Field(
description="The board's new name."
@ -67,6 +84,18 @@ class BoardStorageBase(ABC):
"""Saves a board record."""
pass
def get_cover_photo(self, board_id: str) -> Optional[str]:
"""Gets the cover photo for a board."""
pass
def get_many(
self,
offset: int,
limit: int,
):
"""Gets many board records."""
pass
class SqliteBoardStorage(BoardStorageBase):
_filename: str
@ -177,3 +206,48 @@ class SqliteBoardStorage(BoardStorageBase):
raise BoardRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
offset: int,
limit: int,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
query_conditions = ""
query_params = []
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [BoardRecord(**dict(row)) for row in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=boards, offset=offset, limit=limit, total=count
)

View File

@ -94,6 +94,11 @@ class ImageRecordStorageBase(ABC):
"""Deletes an image record."""
pass
@abstractmethod
def get_board_cover_photo(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the cover photo for a board."""
pass
@abstractmethod
def save(
self,
@ -280,6 +285,32 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
def get_board_cover_photo(self, board_id: str) -> ImageRecord | None:
try:
self._lock.acquire()
self._cursor.execute(
"""
SELECT *
FROM images
WHERE board_id = ?
ORDER BY created_at DESC
LIMIT 1
""",
(board_id),
)
self._conn.commit()
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def get_many(
self,
offset: int = 0,