mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
board CRUD
This commit is contained in:
parent
257e972599
commit
a1671519d5
@ -2,6 +2,7 @@
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services import boards
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
@ -20,6 +21,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.boards import SqliteBoardStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
@ -71,6 +73,7 @@ class ApiDependencies:
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
boards = SqliteBoardStorage(db_location)
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
@ -96,6 +99,7 @@ class ApiDependencies:
|
||||
restoration=RestorationServices(config, logger),
|
||||
configuration=config,
|
||||
logger=logger,
|
||||
boards=boards
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
77
invokeai/app/api/routers/boards.py
Normal file
77
invokeai/app/api/routers/boards.py
Normal file
@ -0,0 +1,77 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.boards import BoardRecord, BoardRecordChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
"/",
|
||||
operation_id="create_board",
|
||||
responses={
|
||||
201: {"description": "The board was created successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board(
|
||||
board_name: str = Body(description="The name of the board to create"),
|
||||
):
|
||||
"""Creates a board"""
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.save(board_name=board_name)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
) -> None:
|
||||
"""Deletes a board"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
operation_id="update_baord"
|
||||
)
|
||||
async def update_baord(
|
||||
id: str = Path(description="The id of the board to update"),
|
||||
board_changes: BoardRecordChanges = Body(
|
||||
description="The changes to apply to the board"
|
||||
),
|
||||
):
|
||||
"""Updates a board"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.boards.update(
|
||||
id, board_changes
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail="Failed to update board")
|
||||
|
||||
@boards_router.get(
|
||||
"/",
|
||||
operation_id="list_boards",
|
||||
response_model=OffsetPaginatedResults[BoardRecord],
|
||||
)
|
||||
async def list_boards(
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of boards per page"),
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets a list of boards"""
|
||||
|
||||
boards = ApiDependencies.invoker.services.boards.get_many(
|
||||
offset,
|
||||
limit,
|
||||
)
|
||||
|
||||
return boards
|
@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.routers import sessions, models, images, boards
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
@ -78,6 +78,8 @@ app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
|
172
invokeai/app/services/boards.py
Normal file
172
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,172 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
"""Deserialized board record."""
|
||||
|
||||
id: str = Field(description="The unique ID of the board.")
|
||||
name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
|
||||
class BoardRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
name: Optional[str] = Field(
|
||||
description="The board's new name."
|
||||
)
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
class BoardStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, board_id: str) -> BoardRecord:
|
||||
"""Gets an board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
):
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardStorage(BoardStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `board` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards(created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_name = old.board_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
):
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (id, name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
@ -135,7 +135,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
@ -152,6 +152,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
board_id TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
|
@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.config import InvokeAISettings
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.boards import BoardStorageBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@ -27,6 +28,7 @@ class InvocationServices:
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
boards: "BoardStorageBase"
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
@ -46,6 +48,7 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: "RestorationServices",
|
||||
configuration: "InvokeAISettings",
|
||||
boards: "BoardStorageBase",
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
@ -58,3 +61,4 @@ class InvocationServices:
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
self.boards = boards
|
||||
|
Loading…
x
Reference in New Issue
Block a user