From 38a948ac9f6e682b82930b5e33cf48f8c27ef30c Mon Sep 17 00:00:00 2001 From: maryhipp Date: Wed, 26 Jun 2024 14:03:03 -0400 Subject: [PATCH] feat(api): add archived query param to board list endpoint to include them in the response --- invokeai/app/api/routers/boards.py | 4 +- .../board_records/board_records_base.py | 2 + .../board_records/board_records_sqlite.py | 63 +++++++++++++------ invokeai/app/services/boards/boards_base.py | 2 + .../app/services/boards/boards_default.py | 8 +-- 5 files changed, 55 insertions(+), 24 deletions(-) diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index f9f1a4bb04..d0116ad42f 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -119,14 +119,16 @@ async def list_boards( all: Optional[bool] = Query(default=None, description="Whether to list all boards"), offset: Optional[int] = Query(default=None, description="The page offset"), limit: Optional[int] = Query(default=None, description="The number of boards per page"), + archived: bool = Query(default=False, description="Whether or not to include archived boards in list"), ) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]: """Gets a list of boards""" if all: - return ApiDependencies.invoker.services.boards.get_all() + return ApiDependencies.invoker.services.boards.get_all(archived) elif offset is not None and limit is not None: return ApiDependencies.invoker.services.boards.get_many( offset, limit, + archived ) else: raise HTTPException( diff --git a/invokeai/app/services/board_records/board_records_base.py b/invokeai/app/services/board_records/board_records_base.py index 30f819618a..149cf9b724 100644 --- a/invokeai/app/services/board_records/board_records_base.py +++ b/invokeai/app/services/board_records/board_records_base.py @@ -43,6 +43,7 @@ class BoardRecordStorageBase(ABC): self, offset: int = 0, limit: int = 10, + archived: bool = False ) -> OffsetPaginatedResults[BoardRecord]: """Gets many board records.""" pass @@ -50,6 +51,7 @@ class BoardRecordStorageBase(ABC): @abstractmethod def get_all( self, + archived: bool = False ) -> list[BoardRecord]: """Gets all board records.""" pass diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index 85c130e887..efdc215b2a 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -148,33 +148,50 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): def get_many( self, offset: int = 0, - limit: int = 10 + limit: int = 10, + archived: bool = False ) -> OffsetPaginatedResults[BoardRecord]: try: self._lock.acquire() - # Get all the boards - self._cursor.execute( - """--sql + # Build base query + base_query = """ SELECT * FROM boards + {archived_filter} ORDER BY created_at DESC LIMIT ? OFFSET ?; - """, - (limit, offset), - ) + """ + + # Determine archived filter condition + if archived: + archived_filter = "" + else: + archived_filter = "WHERE archived = 0" + + final_query = base_query.format(archived_filter=archived_filter) + + # Execute query to fetch boards + self._cursor.execute(final_query, (limit, offset)) result = cast(list[sqlite3.Row], self._cursor.fetchall()) boards = [deserialize_board_record(dict(r)) for r in result] - # Get the total number of boards - self._cursor.execute( - """--sql - SELECT COUNT(*) - FROM boards - WHERE 1=1; + # Determine count query + if archived: + count_query = """ + SELECT COUNT(*) + FROM boards; """ - ) + else: + count_query = """ + SELECT COUNT(*) + FROM boards + WHERE archived = 0; + """ + + # Execute count query + self._cursor.execute(count_query) count = cast(int, self._cursor.fetchone()[0]) @@ -188,18 +205,26 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): def get_all( self, + archived: bool = False ) -> list[BoardRecord]: try: self._lock.acquire() - # Get all the boards - self._cursor.execute( - """--sql + base_query = """ SELECT * FROM boards + {archived_filter} ORDER BY created_at DESC - """ - ) + """ + + if archived: + archived_filter = "" + else: + archived_filter = "WHERE archived = 0" + + final_query = base_query.format(archived_filter=archived_filter) + + self._cursor.execute(final_query) result = cast(list[sqlite3.Row], self._cursor.fetchall()) boards = [deserialize_board_record(dict(r)) for r in result] diff --git a/invokeai/app/services/boards/boards_base.py b/invokeai/app/services/boards/boards_base.py index 6f90334d53..bcf046484f 100644 --- a/invokeai/app/services/boards/boards_base.py +++ b/invokeai/app/services/boards/boards_base.py @@ -47,6 +47,7 @@ class BoardServiceABC(ABC): self, offset: int = 0, limit: int = 10, + archived: bool = False ) -> OffsetPaginatedResults[BoardDTO]: """Gets many boards.""" pass @@ -54,6 +55,7 @@ class BoardServiceABC(ABC): @abstractmethod def get_all( self, + archived: bool = False ) -> list[BoardDTO]: """Gets all boards.""" pass diff --git a/invokeai/app/services/boards/boards_default.py b/invokeai/app/services/boards/boards_default.py index 5b37d6c7ad..ac8b770230 100644 --- a/invokeai/app/services/boards/boards_default.py +++ b/invokeai/app/services/boards/boards_default.py @@ -48,8 +48,8 @@ class BoardService(BoardServiceABC): def delete(self, board_id: str) -> None: self.__invoker.services.board_records.delete(board_id) - def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]: - board_records = self.__invoker.services.board_records.get_many(offset, limit) + def get_many(self, offset: int = 0, limit: int = 10, archived: bool = False) -> OffsetPaginatedResults[BoardDTO]: + board_records = self.__invoker.services.board_records.get_many(offset, limit, archived) board_dtos = [] for r in board_records.items: cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id) @@ -63,8 +63,8 @@ class BoardService(BoardServiceABC): return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) - def get_all(self) -> list[BoardDTO]: - board_records = self.__invoker.services.board_records.get_all() + def get_all(self, archived: bool = False) -> list[BoardDTO]: + board_records = self.__invoker.services.board_records.get_all(archived) board_dtos = [] for r in board_records: cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)