mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(db): add get_all()
method for boards
This is needed to show the full list of boards in the update boards modal.
This commit is contained in:
parent
9ef64016c7
commit
661a94b3de
@ -1,3 +1,4 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
from fastapi import Body, HTTPException, Path, Query
|
from fastapi import Body, HTTPException, Path, Query
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from invokeai.app.services.board_record_storage import BoardChanges
|
from invokeai.app.services.board_record_storage import BoardChanges
|
||||||
@ -19,7 +20,7 @@ boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
|||||||
response_model=BoardDTO,
|
response_model=BoardDTO,
|
||||||
)
|
)
|
||||||
async def create_board(
|
async def create_board(
|
||||||
board_name: str = Body(description="The name of the board to create"),
|
board_name: str = Query(description="The name of the board to create"),
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
"""Creates a board"""
|
"""Creates a board"""
|
||||||
try:
|
try:
|
||||||
@ -70,16 +71,25 @@ async def delete_board(
|
|||||||
@boards_router.get(
|
@boards_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_boards",
|
operation_id="list_boards",
|
||||||
response_model=OffsetPaginatedResults[BoardDTO],
|
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||||
)
|
)
|
||||||
async def list_boards(
|
async def list_boards(
|
||||||
offset: int = Query(default=0, description="The page offset"),
|
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||||
limit: int = Query(default=10, description="The number of boards per page"),
|
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||||
) -> OffsetPaginatedResults[BoardDTO]:
|
limit: Optional[int] = Query(
|
||||||
|
default=None, description="The number of boards per page"
|
||||||
|
),
|
||||||
|
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||||
"""Gets a list of boards"""
|
"""Gets a list of boards"""
|
||||||
|
if all:
|
||||||
results = ApiDependencies.invoker.services.boards.get_many(
|
return ApiDependencies.invoker.services.boards.get_all()
|
||||||
offset,
|
elif offset is not None and limit is not None:
|
||||||
limit,
|
return ApiDependencies.invoker.services.boards.get_many(
|
||||||
)
|
offset,
|
||||||
return results
|
limit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||||
|
)
|
||||||
|
@ -5,12 +5,14 @@ import threading
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import uuid
|
import uuid
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
from invokeai.app.services.models.board_record import (
|
||||||
|
BoardRecord,
|
||||||
|
deserialize_board_record,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, Extra
|
from pydantic import BaseModel, Field, Extra
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||||
board_name: Optional[str] = Field(description="The board's new name.")
|
board_name: Optional[str] = Field(description="The board's new name.")
|
||||||
cover_image_name: Optional[str] = Field(
|
cover_image_name: Optional[str] = Field(
|
||||||
@ -81,6 +83,13 @@ class BoardRecordStorageBase(ABC):
|
|||||||
"""Gets many board records."""
|
"""Gets many board records."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
"""Gets all board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
_filename: str
|
_filename: str
|
||||||
@ -292,3 +301,29 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Get all the boards
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM boards
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||||
|
|
||||||
|
return boards
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
@ -61,6 +61,13 @@ class BoardServiceABC(ABC):
|
|||||||
"""Gets many boards."""
|
"""Gets many boards."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardDTO]:
|
||||||
|
"""Gets all boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceDependencies:
|
class BoardServiceDependencies:
|
||||||
"""Service dependencies for the BoardService."""
|
"""Service dependencies for the BoardService."""
|
||||||
@ -101,8 +108,10 @@ class BoardService(BoardServiceABC):
|
|||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
board_record = self._services.board_records.get(board_id)
|
board_record = self._services.board_records.get(board_id)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_recordboard_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
if (cover_image):
|
board_record.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
@ -117,12 +126,14 @@ class BoardService(BoardServiceABC):
|
|||||||
changes: BoardChanges,
|
changes: BoardChanges,
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
board_record = self._services.board_records.update(board_id, changes)
|
board_record = self._services.board_records.update(board_id, changes)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
if (cover_image):
|
board_record.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
board_id
|
board_id
|
||||||
)
|
)
|
||||||
@ -137,8 +148,10 @@ class BoardService(BoardServiceABC):
|
|||||||
board_records = self._services.board_records.get_many(offset, limit)
|
board_records = self._services.board_records.get_many(offset, limit)
|
||||||
board_dtos = []
|
board_dtos = []
|
||||||
for r in board_records.items:
|
for r in board_records.items:
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
if (cover_image):
|
r.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
@ -151,3 +164,22 @@ class BoardService(BoardServiceABC):
|
|||||||
return OffsetPaginatedResults[BoardDTO](
|
return OffsetPaginatedResults[BoardDTO](
|
||||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
board_records = self._services.board_records.get_all()
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records:
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return board_dtos
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
class BoardRecord(BaseModel):
|
class BoardRecord(BaseModel):
|
||||||
"""Deserialized board record."""
|
"""Deserialized board record."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user