diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 5599d569d5..efeb778922 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,8 +2,17 @@ from logging import Logger import os +from invokeai.app.services.board_image_record_storage import ( + SqliteBoardImageRecordStorage, +) +from invokeai.app.services.board_images import ( + BoardImagesService, + BoardImagesServiceDependencies, +) +from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage +from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService +from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService @@ -57,7 +66,7 @@ class ApiDependencies: # TODO: build a file/path manager? db_location = config.db_path - db_location.parent.mkdir(parents=True,exist_ok=True) + db_location.parent.mkdir(parents=True, exist_ok=True) graph_execution_manager = SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" @@ -72,14 +81,40 @@ class ApiDependencies: DiskLatentsStorage(f"{output_folder}/latents") ) + board_record_storage = SqliteBoardRecordStorage(db_location) + board_image_record_storage = SqliteBoardImageRecordStorage(db_location) + + boards = BoardService( + services=BoardServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) + + board_images = BoardImagesService( + services=BoardImagesServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) + images = ImageService( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=urls, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, + services=ImageServiceDependencies( + board_image_record_storage=board_image_record_storage, + image_record_storage=image_record_storage, + image_file_storage=image_file_storage, + metadata=metadata, + url=urls, + logger=logger, + names=names, + graph_execution_manager=graph_execution_manager, + ) ) services = InvocationServices( @@ -87,6 +122,8 @@ class ApiDependencies: events=events, latents=latents, images=images, + boards=boards, + board_images=board_images, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py new file mode 100644 index 0000000000..b206ab500d --- /dev/null +++ b/invokeai/app/api/routers/board_images.py @@ -0,0 +1,69 @@ +from fastapi import Body, HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.models.image_record import ImageDTO + +from ..dependencies import ApiDependencies + +board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) + + +@board_images_router.post( + "/", + operation_id="create_board_image", + responses={ + 201: {"description": "The image was added to a board successfully"}, + }, + status_code=201, +) +async def create_board_image( + board_id: str = Body(description="The id of the board to add to"), + image_name: str = Body(description="The name of the image to add"), +): + """Creates a board_image""" + try: + result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to add to board") + +@board_images_router.delete( + "/", + operation_id="remove_board_image", + responses={ + 201: {"description": "The image was removed from the board successfully"}, + }, + status_code=201, +) +async def remove_board_image( + board_id: str = Body(description="The id of the board"), + image_name: str = Body(description="The name of the image to remove"), +): + """Deletes a board_image""" + try: + result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to update board") + + + +@board_images_router.get( + "/{board_id}", + operation_id="list_board_images", + response_model=OffsetPaginatedResults[ImageDTO], +) +async def list_board_images( + board_id: str = Path(description="The id of the board"), + offset: int = Query(default=0, description="The page offset"), + limit: int = Query(default=10, description="The number of boards per page"), +) -> OffsetPaginatedResults[ImageDTO]: + """Gets a list of images for a board""" + + results = ApiDependencies.invoker.services.board_images.get_images_for_board( + board_id, + ) + return results + diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py new file mode 100644 index 0000000000..55cd7c8ca2 --- /dev/null +++ b/invokeai/app/api/routers/boards.py @@ -0,0 +1,108 @@ +from typing import Optional, Union +from fastapi import Body, HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.services.board_record_storage import BoardChanges +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import BoardDTO + +from ..dependencies import ApiDependencies + +boards_router = APIRouter(prefix="/v1/boards", tags=["boards"]) + + +@boards_router.post( + "/", + operation_id="create_board", + responses={ + 201: {"description": "The board was created successfully"}, + }, + status_code=201, + response_model=BoardDTO, +) +async def create_board( + board_name: str = Query(description="The name of the board to create"), +) -> BoardDTO: + """Creates a board""" + try: + result = ApiDependencies.invoker.services.boards.create(board_name=board_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to create board") + + +@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO) +async def get_board( + board_id: str = Path(description="The id of board to get"), +) -> BoardDTO: + """Gets a board""" + + try: + result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + return result + except Exception as e: + raise HTTPException(status_code=404, detail="Board not found") + + +@boards_router.patch( + "/{board_id}", + operation_id="update_board", + responses={ + 201: { + "description": "The board was updated successfully", + }, + }, + status_code=201, + response_model=BoardDTO, +) +async def update_board( + board_id: str = Path(description="The id of board to update"), + changes: BoardChanges = Body(description="The changes to apply to the board"), +) -> BoardDTO: + """Updates a board""" + try: + result = ApiDependencies.invoker.services.boards.update( + board_id=board_id, changes=changes + ) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to update board") + + +@boards_router.delete("/{board_id}", operation_id="delete_board") +async def delete_board( + board_id: str = Path(description="The id of board to delete"), +) -> None: + """Deletes a board""" + + try: + ApiDependencies.invoker.services.boards.delete(board_id=board_id) + except Exception as e: + # TODO: Does this need any exception handling at all? + pass + + +@boards_router.get( + "/", + operation_id="list_boards", + response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]], +) +async def list_boards( + all: Optional[bool] = Query(default=None, description="Whether to list all boards"), + offset: Optional[int] = Query(default=None, description="The page offset"), + limit: Optional[int] = Query( + default=None, description="The number of boards per page" + ), +) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]: + """Gets a list of boards""" + if all: + return ApiDependencies.invoker.services.boards.get_all() + elif offset is not None and limit is not None: + return ApiDependencies.invoker.services.boards.get_many( + offset, + limit, + ) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'", + ) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 11453d97f1..a8c84b81b9 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -221,6 +221,9 @@ async def list_images_with_metadata( is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), + board_id: Optional[str] = Query( + default=None, description="The board id to filter by" + ), offset: int = Query(default=0, description="The page offset"), limit: int = Query(default=10, description="The number of images per page"), ) -> OffsetPaginatedResults[ImageDTO]: @@ -232,6 +235,7 @@ async def list_images_with_metadata( image_origin, categories, is_intermediate, + board_id, ) return image_dtos diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 0abcc19dcf..50d645eb57 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel): info: DiffusersModelInfo = Field(description="The converted model info") class ModelsList(BaseModel): - models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend - #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] + models: list[MODEL_CONFIGS] @models_router.get( @@ -72,10 +71,10 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: BaseModelType = Query( + base_model: Optional[BaseModelType] = Query( default=None, description="Base model" ), - model_type: ModelType = Query( + model_type: Optional[ModelType] = Query( default=None, description="The type of model to get" ), ) -> ModelsList: diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 49e9f92cc7..e14c58bab7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config) import invokeai.frontend.web as web_dir from .api.dependencies import ApiDependencies -from .api.routers import sessions, models, images +from .api.routers import sessions, models, images, boards, board_images from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation @@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api") app.include_router(images.images_router, prefix="/api") +app.include_router(boards.boards_router, prefix="/api") + +app.include_router(board_images.board_images_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? def custom_openapi(): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index bd9ab67271..6bc4eb93a4 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput): #fmt: on -class SD1ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" +class PipelineModelField(BaseModel): + """Pipeline model field""" - type: Literal["sd1_model_loader"] = "sd1_model_loader" + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") - model_name: str = Field(default="", description="Model to load") + +class PipelineModelLoaderInvocation(BaseInvocation): + """Loads a pipeline model, outputting its submodels.""" + + type: Literal["pipeline_model_loader"] = "pipeline_model_loader" + + model: PipelineModelField = Field(description="The model to load") # TODO: precision? # Schema customisation @@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation): "ui": { "tags": ["model", "loader"], "type_hints": { - "model_name": "model" # TODO: rename to model_name? + "model": "model" } }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = BaseModelType.StableDiffusion1 # TODO: + base_model = self.model.base_model + model_name = self.model.model_name + model_type = ModelType.Pipeline # TODO: not found exceptions if not context.services.model_manager.model_exists( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, ): - raise Exception(f"Unkown model name: {self.model_name}!") + raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") """ if not context.services.model_manager.model_exists( @@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation): return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.TextEncoder, ), loras=[], ), vae=VaeField( vae=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Vae, - ), - ) - ) - -# TODO: optimize(less code copy) -class SD2ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" - - type: Literal["sd2_model_loader"] = "sd2_model_loader" - - model_name: str = Field(default="", description="Model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["model", "loader"], - "type_hints": { - "model_name": "model" # TODO: rename to model_name? - } - }, - } - - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - - base_model = BaseModelType.StableDiffusion2 # TODO: - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - ): - raise Exception(f"Unkown model name: {self.model_name}!") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ - - - return ModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.TextEncoder, - ), - loras=[], - ), - vae=VaeField( - vae=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Vae, ), ) diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py new file mode 100644 index 0000000000..7aff41860c --- /dev/null +++ b/invokeai/app/services/board_image_record_storage.py @@ -0,0 +1,254 @@ +from abc import ABC, abstractmethod +import sqlite3 +import threading +from typing import Union, cast +from invokeai.app.services.board_record_storage import BoardRecord + +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.image_record import ( + ImageRecord, + deserialize_image_record, +) + + +class BoardImageRecordStorageBase(ABC): + """Abstract base class for the one-to-many board-image relationship record storage.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageRecord]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" + pass + + @abstractmethod + def get_image_count_for_board( + self, + board_id: str, + ) -> int: + """Gets the number of images for a board.""" + pass + + +class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `board_images` junction table.""" + + # Create the `board_images` junction table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS board_images ( + board_id TEXT NOT NULL, + image_name TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + -- enforce one-to-many relationship between boards and images using PK + -- (we can extend this to many-to-many later) + PRIMARY KEY (image_name), + FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE, + FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE + ); + """ + ) + + # Add index for board id + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id); + """ + ) + + # Add index for board id, sorted by created_at + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at + AFTER UPDATE + ON board_images FOR EACH ROW + BEGIN + UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE board_id = old.board_id AND image_name = old.image_name; + END; + """ + ) + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT INTO board_images (board_id, image_name) + VALUES (?, ?) + ON CONFLICT (image_name) DO UPDATE SET board_id = ?; + """, + (board_id, image_name, board_id), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM board_images + WHERE board_id = ? AND image_name = ?; + """, + (board_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_images_for_board( + self, + board_id: str, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[ImageRecord]: + # TODO: this isn't paginated yet? + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT images.* + FROM board_images + INNER JOIN images ON board_images.image_name = images.image_name + WHERE board_images.board_id = ? + ORDER BY board_images.updated_at DESC; + """, + (board_id,), + ) + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM images WHERE 1=1; + """ + ) + count = cast(int, self._cursor.fetchone()[0]) + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count + ) + + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT board_id + FROM board_images + WHERE image_name = ?; + """, + (image_name,), + ) + result = self._cursor.fetchone() + if result is None: + return None + return cast(str, result[0]) + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_image_count_for_board(self, board_id: str) -> int: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM board_images WHERE board_id = ?; + """, + (board_id,), + ) + count = cast(int, self._cursor.fetchone()[0]) + return count + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py new file mode 100644 index 0000000000..072effbfae --- /dev/null +++ b/invokeai/app/services/board_images.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +from logging import Logger +from typing import List, Union +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_record_storage import ( + BoardRecord, + BoardRecordStorageBase, +) + +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto +from invokeai.app.services.urls import UrlServiceBase + + +class BoardImagesServiceABC(ABC): + """High-level service for board-image relationship management.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" + pass + + +class BoardImagesServiceDependencies: + """Service dependencies for the BoardImagesService.""" + + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger + + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger + + +class BoardImagesService(BoardImagesServiceABC): + _services: BoardImagesServiceDependencies + + def __init__(self, services: BoardImagesServiceDependencies): + self._services = services + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.add_image_to_board(board_id, image_name) + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.remove_image_from_board(board_id, image_name) + + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + image_records = self._services.board_image_records.get_images_for_board( + board_id + ) + image_dtos = list( + map( + lambda r: image_record_to_dto( + r, + self._services.urls.get_image_url(r.image_name), + self._services.urls.get_image_url(r.image_name, True), + board_id, + ), + image_records.items, + ) + ) + return OffsetPaginatedResults[ImageDTO]( + items=image_dtos, + offset=image_records.offset, + limit=image_records.limit, + total=image_records.total, + ) + + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + board_id = self._services.board_image_records.get_board_for_image(image_name) + return board_id + + +def board_record_to_dto( + board_record: BoardRecord, cover_image_name: str | None, image_count: int +) -> BoardDTO: + """Converts a board record to a board DTO.""" + return BoardDTO( + **board_record.dict(exclude={'cover_image_name'}), + cover_image_name=cover_image_name, + image_count=image_count, + ) diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py new file mode 100644 index 0000000000..15ea9cc5a7 --- /dev/null +++ b/invokeai/app/services/board_record_storage.py @@ -0,0 +1,329 @@ +from abc import ABC, abstractmethod +from typing import Optional, cast +import sqlite3 +import threading +from typing import Optional, Union +import uuid +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import ( + BoardRecord, + deserialize_board_record, +) + +from pydantic import BaseModel, Field, Extra + + +class BoardChanges(BaseModel, extra=Extra.forbid): + board_name: Optional[str] = Field(description="The board's new name.") + cover_image_name: Optional[str] = Field( + description="The name of the board's new cover image." + ) + + +class BoardRecordNotFoundException(Exception): + """Raised when an board record is not found.""" + + def __init__(self, message="Board record not found"): + super().__init__(message) + + +class BoardRecordSaveException(Exception): + """Raised when an board record cannot be saved.""" + + def __init__(self, message="Board record not saved"): + super().__init__(message) + + +class BoardRecordDeleteException(Exception): + """Raised when an board record cannot be deleted.""" + + def __init__(self, message="Board record not deleted"): + super().__init__(message) + + +class BoardRecordStorageBase(ABC): + """Low-level service responsible for interfacing with the board record store.""" + + @abstractmethod + def delete(self, board_id: str) -> None: + """Deletes a board record.""" + pass + + @abstractmethod + def save( + self, + board_name: str, + ) -> BoardRecord: + """Saves a board record.""" + pass + + @abstractmethod + def get( + self, + board_id: str, + ) -> BoardRecord: + """Gets a board record.""" + pass + + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardRecord: + """Updates a board record.""" + pass + + @abstractmethod + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + """Gets many board records.""" + pass + + @abstractmethod + def get_all( + self, + ) -> list[BoardRecord]: + """Gets all board records.""" + pass + + +class SqliteBoardRecordStorage(BoardRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `boards` table and `board_images` junction table.""" + + # Create the `boards` table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS boards ( + board_id TEXT NOT NULL PRIMARY KEY, + board_name TEXT NOT NULL, + cover_image_name TEXT, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL + ); + """ + ) + + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at + AFTER UPDATE + ON boards FOR EACH ROW + BEGIN + UPDATE boards SET updated_at = current_timestamp + WHERE board_id = old.board_id; + END; + """ + ) + + def delete(self, board_id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordDeleteException from e + except Exception as e: + self._conn.rollback() + raise BoardRecordDeleteException from e + finally: + self._lock.release() + + def save( + self, + board_name: str, + ) -> BoardRecord: + try: + board_id = str(uuid.uuid4()) + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO boards (board_id, board_name) + VALUES (?, ?); + """, + (board_id, board_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + return self.get(board_id) + + def get( + self, + board_id: str, + ) -> BoardRecord: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + + result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordNotFoundException from e + finally: + self._lock.release() + if result is None: + raise BoardRecordNotFoundException + return BoardRecord(**dict(result)) + + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardRecord: + try: + self._lock.acquire() + + # Change the name of a board + if changes.board_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET board_name = ? + WHERE board_id = ?; + """, + (changes.board_name, board_id), + ) + + # Change the cover image of a board + if changes.cover_image_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET cover_image_name = ? + WHERE board_id = ?; + """, + (changes.cover_image_name, board_id), + ) + + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + return self.get(board_id) + + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + try: + self._lock.acquire() + + # Get all the boards + self._cursor.execute( + """--sql + SELECT * + FROM boards + ORDER BY created_at DESC + LIMIT ? OFFSET ?; + """, + (limit, offset), + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + + # Get the total number of boards + self._cursor.execute( + """--sql + SELECT COUNT(*) + FROM boards + WHERE 1=1; + """ + ) + + count = cast(int, self._cursor.fetchone()[0]) + + return OffsetPaginatedResults[BoardRecord]( + items=boards, offset=offset, limit=limit, total=count + ) + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_all( + self, + ) -> list[BoardRecord]: + try: + self._lock.acquire() + + # Get all the boards + self._cursor.execute( + """--sql + SELECT * + FROM boards + ORDER BY created_at DESC + """ + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + + return boards + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py new file mode 100644 index 0000000000..9361322e6c --- /dev/null +++ b/invokeai/app/services/boards.py @@ -0,0 +1,185 @@ +from abc import ABC, abstractmethod + +from logging import Logger +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_images import board_record_to_dto + +from invokeai.app.services.board_record_storage import ( + BoardChanges, + BoardRecordStorageBase, +) +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.urls import UrlServiceBase + + +class BoardServiceABC(ABC): + """High-level service for board management.""" + + @abstractmethod + def create( + self, + board_name: str, + ) -> BoardDTO: + """Creates a board.""" + pass + + @abstractmethod + def get_dto( + self, + board_id: str, + ) -> BoardDTO: + """Gets a board.""" + pass + + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + """Updates a board.""" + pass + + @abstractmethod + def delete( + self, + board_id: str, + ) -> None: + """Deletes a board.""" + pass + + @abstractmethod + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardDTO]: + """Gets many boards.""" + pass + + @abstractmethod + def get_all( + self, + ) -> list[BoardDTO]: + """Gets all boards.""" + pass + + +class BoardServiceDependencies: + """Service dependencies for the BoardService.""" + + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger + + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger + + +class BoardService(BoardServiceABC): + _services: BoardServiceDependencies + + def __init__(self, services: BoardServiceDependencies): + self._services = services + + def create( + self, + board_name: str, + ) -> BoardDTO: + board_record = self._services.board_records.save(board_name) + return board_record_to_dto(board_record, None, 0) + + def get_dto(self, board_id: str) -> BoardDTO: + board_record = self._services.board_records.get(board_id) + cover_image = self._services.image_records.get_most_recent_image_for_board( + board_record.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + image_count = self._services.board_image_records.get_image_count_for_board( + board_id + ) + return board_record_to_dto(board_record, cover_image_name, image_count) + + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + board_record = self._services.board_records.update(board_id, changes) + cover_image = self._services.image_records.get_most_recent_image_for_board( + board_record.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + board_id + ) + return board_record_to_dto(board_record, cover_image_name, image_count) + + def delete(self, board_id: str) -> None: + self._services.board_records.delete(board_id) + + def get_many( + self, offset: int = 0, limit: int = 10 + ) -> OffsetPaginatedResults[BoardDTO]: + board_records = self._services.board_records.get_many(offset, limit) + board_dtos = [] + for r in board_records.items: + cover_image = self._services.image_records.get_most_recent_image_for_board( + r.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) + + return OffsetPaginatedResults[BoardDTO]( + items=board_dtos, offset=offset, limit=limit, total=len(board_dtos) + ) + + def get_all(self) -> list[BoardDTO]: + board_records = self._services.board_records.get_all() + board_dtos = [] + for r in board_records: + cover_image = self._services.image_records.get_most_recent_image_for_board( + r.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) + + return board_dtos \ No newline at end of file diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 30b379ed8b..c34d2ca5c8 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC): """Saves an image record.""" pass + @abstractmethod + def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None: + """Gets the most recent image for a board.""" + pass + class SqliteImageRecordStorage(ImageRecordStorageBase): _filename: str @@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._lock.release() def _create_tables(self) -> None: - """Creates the tables for the `images` database.""" + """Creates the `images` table.""" # Create the `images` table. self._cursor.execute( @@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id TEXT, metadata TEXT, is_intermediate BOOLEAN DEFAULT FALSE, + board_id TEXT, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), @@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): AFTER UPDATE ON images FOR EACH ROW BEGIN - UPDATE images SET updated_at = current_timestamp + UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') WHERE image_name = old.image_name; END; """ @@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """, (changes.is_intermediate, image_name), ) + self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() # Manually build two queries - one for the count, one for the records + count_query = """--sql + SELECT COUNT(*) + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ - count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" - images_query = f"""SELECT * FROM images WHERE 1=1\n""" + images_query = """--sql + SELECT images.* + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ query_conditions = "" query_params = [] if image_origin is not None: - query_conditions += f"""AND image_origin = ?\n""" + query_conditions += """--sql + AND images.image_origin = ? + """ query_params.append(image_origin.value) if categories is not None: - ## Convert the enum values to unique list of strings + # Convert the enum values to unique list of strings category_strings = list(map(lambda c: c.value, set(categories))) # Create the correct length of placeholders placeholders = ",".join("?" * len(category_strings)) - query_conditions += f"AND image_category IN ( {placeholders} )\n" + + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ # Unpack the included categories into the query params for c in category_strings: query_params.append(c) if is_intermediate is not None: - query_conditions += f"""AND is_intermediate = ?\n""" + query_conditions += """--sql + AND images.is_intermediate = ? + """ + query_params.append(is_intermediate) - query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" + if board_id is not None: + query_conditions += """--sql + AND board_images.board_id = ? + """ + + query_params.append(board_id) + + query_pagination = """--sql + ORDER BY images.created_at DESC LIMIT ? OFFSET ? + """ # Final images query with pagination images_query += query_conditions + query_pagination + ";" @@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): count_query += query_conditions + ";" count_params = query_params.copy() self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] + count = cast(int, self._cursor.fetchone()[0]) except sqlite3.Error as e: self._conn.rollback() raise e @@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): raise ImageRecordSaveException from e finally: self._lock.release() + + def get_most_recent_image_for_board( + self, board_id: str + ) -> Union[ImageRecord, None]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT images.* + FROM images + JOIN board_images ON images.image_name = board_images.image_name + WHERE board_images.board_id = ? + ORDER BY images.created_at DESC + LIMIT 1; + """, + (board_id,), + ) + + result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) + finally: + self._lock.release() + if result is None: + return None + + return deserialize_image_record(dict(result)) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 9f7188f607..542f874f1d 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -10,6 +10,7 @@ from invokeai.app.models.image import ( InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.image_record_storage import ( ImageRecordDeleteException, ImageRecordNotFoundException, @@ -49,7 +50,7 @@ class ImageServiceABC(ABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - intermediate: bool = False, + is_intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -79,7 +80,7 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_path(self, image_name: str) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: """Gets an image's path.""" pass @@ -101,6 +102,7 @@ class ImageServiceABC(ABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @@ -114,8 +116,9 @@ class ImageServiceABC(ABC): class ImageServiceDependencies: """Service dependencies for the ImageService.""" - records: ImageRecordStorageBase - files: ImageFileStorageBase + image_records: ImageRecordStorageBase + image_files: ImageFileStorageBase + board_image_records: BoardImageRecordStorageBase metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger @@ -126,14 +129,16 @@ class ImageServiceDependencies: self, image_record_storage: ImageRecordStorageBase, image_file_storage: ImageFileStorageBase, + board_image_record_storage: BoardImageRecordStorageBase, metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): - self.records = image_record_storage - self.files = image_file_storage + self.image_records = image_record_storage + self.image_files = image_file_storage + self.board_image_records = board_image_record_storage self.metadata = metadata self.urls = url self.logger = logger @@ -144,25 +149,8 @@ class ImageServiceDependencies: class ImageService(ImageServiceABC): _services: ImageServiceDependencies - def __init__( - self, - image_record_storage: ImageRecordStorageBase, - image_file_storage: ImageFileStorageBase, - metadata: MetadataServiceBase, - url: UrlServiceBase, - logger: Logger, - names: NameServiceBase, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - ): - self._services = ImageServiceDependencies( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=url, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, - ) + def __init__(self, services: ImageServiceDependencies): + self._services = services def create( self, @@ -187,7 +175,7 @@ class ImageService(ImageServiceABC): try: # TODO: Consider using a transaction here to ensure consistency between storage and database - created_at = self._services.records.save( + self._services.image_records.save( # Non-nullable fields image_name=image_name, image_origin=image_origin, @@ -202,35 +190,15 @@ class ImageService(ImageServiceABC): metadata=metadata, ) - self._services.files.save( + self._services.image_files.save( image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_name) - thumbnail_url = self._services.urls.get_image_url(image_name, True) + image_dto = self.get_dto(image_name) - return ImageDTO( - # Non-nullable fields - image_name=image_name, - image_origin=image_origin, - image_category=image_category, - width=width, - height=height, - # Nullable fields - node_id=node_id, - session_id=session_id, - metadata=metadata, - # Meta fields - created_at=created_at, - updated_at=created_at, # this is always the same as the created_at at this time - deleted_at=None, - is_intermediate=is_intermediate, - # Extra non-nullable fields for DTO - image_url=image_url, - thumbnail_url=thumbnail_url, - ) + return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to save image record") raise @@ -247,7 +215,7 @@ class ImageService(ImageServiceABC): changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, changes) + self._services.image_records.update(image_name, changes) return self.get_dto(image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") @@ -258,7 +226,7 @@ class ImageService(ImageServiceABC): def get_pil_image(self, image_name: str) -> PILImageType: try: - return self._services.files.get(image_name) + return self._services.image_files.get(image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -268,7 +236,7 @@ class ImageService(ImageServiceABC): def get_record(self, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_name) + return self._services.image_records.get(image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -278,12 +246,13 @@ class ImageService(ImageServiceABC): def get_dto(self, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_name) + image_record = self._services.image_records.get(image_name) image_dto = image_record_to_dto( image_record, self._services.urls.get_image_url(image_name), self._services.urls.get_image_url(image_name, True), + self._services.board_image_records.get_board_for_image(image_name), ) return image_dto @@ -296,14 +265,14 @@ class ImageService(ImageServiceABC): def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.files.get_path(image_name, thumbnail) + return self._services.image_files.get_path(image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e def validate_path(self, path: str) -> bool: try: - return self._services.files.validate_path(path) + return self._services.image_files.validate_path(path) except Exception as e: self._services.logger.error("Problem validating image path") raise e @@ -322,14 +291,16 @@ class ImageService(ImageServiceABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: try: - results = self._services.records.get_many( + results = self._services.image_records.get_many( offset, limit, image_origin, categories, is_intermediate, + board_id, ) image_dtos = list( @@ -338,6 +309,9 @@ class ImageService(ImageServiceABC): r, self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name, True), + self._services.board_image_records.get_board_for_image( + r.image_name + ), ), results.items, ) @@ -355,8 +329,8 @@ class ImageService(ImageServiceABC): def delete(self, image_name: str): try: - self._services.files.delete(image_name) - self._services.records.delete(image_name) + self._services.image_files.delete(image_name) + self._services.image_records.delete(image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 1f910253e5..10d1d91920 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -4,7 +4,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from logging import Logger - from invokeai.app.services.images import ImageService + from invokeai.app.services.board_images import BoardImagesServiceABC + from invokeai.app.services.boards import BoardServiceABC + from invokeai.app.services.images import ImageServiceABC from invokeai.backend import ModelManager from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase @@ -26,9 +28,9 @@ class InvocationServices: model_manager: "ModelManager" restoration: "RestorationServices" configuration: "InvokeAISettings" - images: "ImageService" - - # NOTE: we must forward-declare any types that include invocations, since invocations can use services + images: "ImageServiceABC" + boards: "BoardServiceABC" + board_images: "BoardImagesServiceABC" graph_library: "ItemStorageABC"["LibraryGraph"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] processor: "InvocationProcessorABC" @@ -39,7 +41,9 @@ class InvocationServices: events: "EventServiceBase", logger: "Logger", latents: "LatentsStorageBase", - images: "ImageService", + images: "ImageServiceABC", + boards: "BoardServiceABC", + board_images: "BoardImagesServiceABC", queue: "InvocationQueueABC", graph_library: "ItemStorageABC"["LibraryGraph"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], @@ -52,9 +56,12 @@ class InvocationServices: self.logger = logger self.latents = latents self.images = images + self.boards = boards + self.board_images = board_images self.queue = queue self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration self.configuration = configuration + self.boards = boards diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8956b55139..8b46b17ad0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -5,7 +5,7 @@ from __future__ import annotations import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING +from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass from invokeai.backend.model_management.model_manager import ( @@ -273,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> dict: + ) -> list[dict]: + # ) -> dict: """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } + Return a list of models. """ return self.mgr.list_models(base_model, model_type) diff --git a/invokeai/app/services/models/board_record.py b/invokeai/app/services/models/board_record.py new file mode 100644 index 0000000000..bf5401b209 --- /dev/null +++ b/invokeai/app/services/models/board_record.py @@ -0,0 +1,62 @@ +from typing import Optional, Union +from datetime import datetime +from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr +from invokeai.app.util.misc import get_iso_timestamp + + +class BoardRecord(BaseModel): + """Deserialized board record.""" + + board_id: str = Field(description="The unique ID of the board.") + """The unique ID of the board.""" + board_name: str = Field(description="The name of the board.") + """The name of the board.""" + created_at: Union[datetime, str] = Field( + description="The created timestamp of the board." + ) + """The created timestamp of the image.""" + updated_at: Union[datetime, str] = Field( + description="The updated timestamp of the board." + ) + """The updated timestamp of the image.""" + deleted_at: Union[datetime, str, None] = Field( + description="The deleted timestamp of the board." + ) + """The updated timestamp of the image.""" + cover_image_name: Optional[str] = Field( + description="The name of the cover image of the board." + ) + """The name of the cover image of the board.""" + + +class BoardDTO(BoardRecord): + """Deserialized board record with cover image URL and image count.""" + + cover_image_name: Optional[str] = Field( + description="The name of the board's cover image." + ) + """The URL of the thumbnail of the most recent image in the board.""" + image_count: int = Field(description="The number of images in the board.") + """The number of images in the board.""" + + +def deserialize_board_record(board_dict: dict) -> BoardRecord: + """Deserializes a board record.""" + + # Retrieve all the values, setting "reasonable" defaults if they are not present. + + board_id = board_dict.get("board_id", "unknown") + board_name = board_dict.get("board_name", "unknown") + cover_image_name = board_dict.get("cover_image_name", "unknown") + created_at = board_dict.get("created_at", get_iso_timestamp()) + updated_at = board_dict.get("updated_at", get_iso_timestamp()) + deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) + + return BoardRecord( + board_id=board_id, + board_name=board_name, + cover_image_name=cover_image_name, + created_at=created_at, + updated_at=updated_at, + deleted_at=deleted_at, + ) diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index d971d65916..cc02016cf9 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel): class ImageDTO(ImageRecord, ImageUrlsDTO): - """Deserialized image record, enriched for the frontend with URLs.""" + """Deserialized image record, enriched for the frontend.""" + board_id: Union[str, None] = Field( + description="The id of the board the image belongs to, if one exists." + ) + """The id of the board the image belongs to, if one exists.""" pass def image_record_to_dto( - image_record: ImageRecord, image_url: str, thumbnail_url: str + image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None] ) -> ImageDTO: """Converts an image record to an image DTO.""" return ImageDTO( **image_record.dict(), image_url=image_url, thumbnail_url=thumbnail_url, + board_id=board_id, ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 9a8c7e64c6..f9a66a87dd 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -473,9 +473,9 @@ class ModelManager(object): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> Dict[str, Dict[str, str]]: + ) -> list[dict]: """ - Return a dict of models, in format [base_model][model_type][model_name] + Return a list of models. Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model @@ -483,7 +483,7 @@ class ModelManager(object): object derived from models.yaml """ - models = dict() + models = [] for model_key in sorted(self.models, key=str.casefold): model_config = self.models[model_key] @@ -493,20 +493,16 @@ class ModelManager(object): if model_type is not None and cur_model_type != model_type: continue - if cur_base_model not in models: - models[cur_base_model] = dict() - if cur_model_type not in models[cur_base_model]: - models[cur_base_model][cur_model_type] = dict() - - models[cur_base_model][cur_model_type][cur_model_name] = dict( + model_dict = dict( **model_config.dict(exclude_defaults=True), - # OpenAPIModelInfoBase name=cur_model_name, base_model=cur_base_model, type=cur_model_type, ) + models.append(model_dict) + return models def print_models(self) -> None: diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index ddc6dace27..55fcc97745 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { useListModelsQuery } from 'services/apiSlice'; const DEFAULT_CONFIG = {}; @@ -45,6 +47,18 @@ const App = ({ const isApplicationReady = useIsApplicationReady(); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + const { data: controlnetModels } = useListModelsQuery({ + model_type: 'controlnet', + }); + const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' }); + const { data: loraModels } = useListModelsQuery({ model_type: 'lora' }); + const { data: embeddingModels } = useListModelsQuery({ + model_type: 'embedding', + }); + const [loadingOverridden, setLoadingOverridden] = useState(false); const dispatch = useAppDispatch(); @@ -143,6 +157,7 @@ const App = ({ + diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 0537d1de2a..141e62652d 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -21,6 +21,8 @@ import { DeleteImageContext, DeleteImageContextProvider, } from 'app/contexts/DeleteImageContext'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; const App = lazy(() => import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -76,11 +78,13 @@ const InvokeAIUI = ({ - + + + diff --git a/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx new file mode 100644 index 0000000000..f5a856d3d8 --- /dev/null +++ b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx @@ -0,0 +1,89 @@ +import { useDisclosure } from '@chakra-ui/react'; +import { PropsWithChildren, createContext, useCallback, useState } from 'react'; +import { ImageDTO } from 'services/api'; +import { useAddImageToBoardMutation } from 'services/apiSlice'; + +export type ImageUsage = { + isInitialImage: boolean; + isCanvasImage: boolean; + isNodesImage: boolean; + isControlNetImage: boolean; +}; + +type AddImageToBoardContextValue = { + /** + * Whether the move image dialog is open. + */ + isOpen: boolean; + /** + * Closes the move image dialog. + */ + onClose: () => void; + /** + * The image pending movement + */ + image?: ImageDTO; + onClickAddToBoard: (image: ImageDTO) => void; + handleAddToBoard: (boardId: string) => void; +}; + +export const AddImageToBoardContext = + createContext({ + isOpen: false, + onClose: () => undefined, + onClickAddToBoard: () => undefined, + handleAddToBoard: () => undefined, + }); + +type Props = PropsWithChildren; + +export const AddImageToBoardContextProvider = (props: Props) => { + const [imageToMove, setImageToMove] = useState(); + const { isOpen, onOpen, onClose } = useDisclosure(); + + const [addImageToBoard, result] = useAddImageToBoardMutation(); + + // Clean up after deleting or dismissing the modal + const closeAndClearImageToDelete = useCallback(() => { + setImageToMove(undefined); + onClose(); + }, [onClose]); + + const onClickAddToBoard = useCallback( + (image?: ImageDTO) => { + if (!image) { + return; + } + setImageToMove(image); + onOpen(); + }, + [setImageToMove, onOpen] + ); + + const handleAddToBoard = useCallback( + (boardId: string) => { + if (imageToMove) { + addImageToBoard({ + board_id: boardId, + image_name: imageToMove.image_name, + }); + closeAndClearImageToDelete(); + } + }, + [addImageToBoard, closeAndClearImageToDelete, imageToMove] + ); + + return ( + + {props.children} + + ); +}; diff --git a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx index 8263b48114..d01298944b 100644 --- a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx +++ b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx @@ -35,25 +35,23 @@ export const selectImageUsage = createSelector( (state: RootState, image_name?: string) => image_name, ], (generation, canvas, nodes, controlNet, image_name) => { - const isInitialImage = generation.initialImage?.image_name === image_name; + const isInitialImage = generation.initialImage?.imageName === image_name; const isCanvasImage = canvas.layerState.objects.some( - (obj) => obj.kind === 'image' && obj.image.image_name === image_name + (obj) => obj.kind === 'image' && obj.imageName === image_name ); const isNodesImage = nodes.nodes.some((node) => { return some( node.data.inputs, - (input) => - input.type === 'image' && input.value?.image_name === image_name + (input) => input.type === 'image' && input.value === image_name ); }); const isControlNetImage = some( controlNet.controlNets, (c) => - c.controlImage?.image_name === image_name || - c.processedControlImage?.image_name === image_name + c.controlImage === image_name || c.processedControlImage === image_name ); const imageUsage: ImageUsage = { diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index e498ecb749..cb18d48301 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; -import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { omit } from 'lodash-es'; @@ -18,8 +17,6 @@ const serializationDenylist: { gallery: galleryPersistDenylist, generation: generationPersistDenylist, lightbox: lightboxPersistDenylist, - sd1models: modelsPersistDenylist, - sd2models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index 649b56316d..8f40b0bb59 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,8 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { sd1InitialPipelineModelsState } from 'features/system/store/models/sd1PipelineModelSlice'; -import { sd2InitialPipelineModelsState } from 'features/system/store/models/sd2PipelineModelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -22,8 +20,6 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - sd1PipelineModels: sd1InitialPipelineModelsState, - sd2PipelineModels: sd2InitialPipelineModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 8c073e81d6..cb641d00db 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect'; +import { + addImageAddedToBoardFulfilledListener, + addImageAddedToBoardRejectedListener, +} from './listeners/imageAddedToBoard'; +import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; +import { + addImageRemovedFromBoardFulfilledListener, + addImageRemovedFromBoardRejectedListener, +} from './listeners/imageRemovedFromBoard'; export const listenerMiddleware = createListenerMiddleware(); @@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect< AppDispatch >; +/** + * The RTK listener middleware is a lightweight alternative sagas/observables. + * + * Most side effect logic should live in a listener. + */ + // Image uploaded addImageUploadedFulfilledListener(); addImageUploadedRejectedListener(); @@ -183,3 +198,10 @@ addControlNetAutoProcessListener(); // Update image URLs on connect addUpdateImageUrlsOnConnectListener(); + +// Boards +addImageAddedToBoardFulfilledListener(); +addImageAddedToBoardRejectedListener(); +addImageRemovedFromBoardFulfilledListener(); +addImageRemovedFromBoardRejectedListener(); +addBoardIdSelectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts new file mode 100644 index 0000000000..eab4389ceb --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -0,0 +1,99 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { boardIdSelected } from 'features/gallery/store/boardSlice'; +import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addBoardIdSelectedListener = () => { + startAppListening({ + actionCreator: boardIdSelected, + effect: (action, { getState, dispatch }) => { + const boardId = action.payload; + + // we need to check if we need to fetch more images + + const state = getState(); + const allImages = selectImagesAll(state); + + if (!boardId) { + // a board was unselected + dispatch(imageSelected(allImages[0]?.image_name)); + return; + } + + const { categories } = state.images; + + const filteredImages = allImages.filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = boardId ? i.board_id === boardId : true; + return isInCategory && isInSelectedBoard; + }); + + // get the board from the cache + const { data: boards } = api.endpoints.listAllBoards.select()(state); + const board = boards?.find((b) => b.board_id === boardId); + + if (!board) { + // can't find the board in cache... + dispatch(imageSelected(allImages[0]?.image_name)); + return; + } + + dispatch(imageSelected(board.cover_image_name)); + + // if we haven't loaded one full page of images from this board, load more + if ( + filteredImages.length < board.image_count && + filteredImages.length < IMAGES_PER_PAGE + ) { + dispatch(receivedPageOfImages({ categories, boardId })); + } + }, + }); +}; + +export const addBoardIdSelected_changeSelectedImage_listener = () => { + startAppListening({ + actionCreator: boardIdSelected, + effect: (action, { getState, dispatch }) => { + const boardId = action.payload; + + const state = getState(); + + // we need to check if we need to fetch more images + + if (!boardId) { + // a board was unselected - we don't need to do anything + return; + } + + const { categories } = state.images; + + const filteredImages = selectImagesAll(state).filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = boardId ? i.board_id === boardId : true; + return isInCategory && isInSelectedBoard; + }); + + // get the board from the cache + const { data: boards } = api.endpoints.listAllBoards.select()(state); + const board = boards?.find((b) => b.board_id === boardId); + if (!board) { + // can't find the board in cache... + return; + } + + // if we haven't loaded one full page of images from this board, load more + if ( + filteredImages.length < board.image_count && + filteredImages.length < IMAGES_PER_PAGE + ) { + dispatch(receivedPageOfImages({ categories, boardId })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index ce1b515b84..7ff9a5118c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => { [controlNet.processorNode.id]: { ...controlNet.processorNode, is_intermediate: true, - image: pick(controlNet.controlImage, ['image_name']), + image: { image_name: controlNet.controlImage }, }, }, }; @@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => { dispatch( controlNetProcessedImageChanged({ controlNetId, - processedControlImage, + processedControlImage: processedControlImage.image_name, }) ); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts new file mode 100644 index 0000000000..0f404cab68 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts @@ -0,0 +1,40 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addImageAddedToBoardFulfilledListener = () => { + startAppListening({ + matcher: api.endpoints.addImageToBoard.matchFulfilled, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Image added to board' + ); + + dispatch( + imageMetadataReceived({ + imageName: image_name, + }) + ); + }, + }); +}; + +export const addImageAddedToBoardRejectedListener = () => { + startAppListening({ + matcher: api.endpoints.addImageToBoard.matchRejected, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Problem adding image to board' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts index 85d56d3913..8f01b8d7b8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts @@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => { startAppListening({ actionCreator: imageCategoriesChanged, effect: (action, { getState, dispatch }) => { - const filteredImagesCount = selectFilteredImagesAsArray( - getState() - ).length; + const state = getState(); + const filteredImagesCount = selectFilteredImagesAsArray(state).length; if (!filteredImagesCount) { - dispatch(receivedPageOfImages()); + dispatch( + receivedPageOfImages({ + categories: action.payload, + boardId: state.boards.selectedBoardId, + }) + ); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 4c0c057242..224aa0d2aa 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -6,15 +6,15 @@ import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageRemoved, - selectImagesEntities, selectImagesIds, } from 'features/gallery/store/imagesSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { api } from 'services/apiSlice'; -const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); +const moduleLog = log.child({ namespace: 'image' }); /** * Called when the user requests an image deletion @@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); export const addRequestedImageDeletionListener = () => { startAppListening({ actionCreator: requestedImageDeletion, - effect: (action, { dispatch, getState }) => { + effect: async (action, { dispatch, getState, condition }) => { const { image, imageUsage } = action.payload; const { image_name } = image; @@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => { const state = getState(); const selectedImage = state.gallery.selectedImage; - if (selectedImage && selectedImage.image_name === image_name) { + if (selectedImage === image_name) { const ids = selectImagesIds(state); - const entities = selectImagesEntities(state); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name @@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => { const newSelectedImageId = filteredIds[newSelectedImageIndex]; - const newSelectedImage = entities[newSelectedImageId]; - if (newSelectedImageId) { - dispatch(imageSelected(newSelectedImage)); + dispatch(imageSelected(newSelectedImageId as string)); } else { dispatch(imageSelected()); } @@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => { dispatch(imageRemoved(image_name)); // Delete from server - dispatch(imageDeleted({ imageName: image_name })); + const { requestId } = dispatch(imageDeleted({ imageName: image_name })); + + // Wait for successful deletion, then trigger boards to re-fetch + const wasImageDeleted = await condition( + (action): action is ReturnType => + imageDeleted.fulfilled.match(action) && + action.meta.requestId === requestId, + 30000 + ); + + if (wasImageDeleted) { + dispatch( + api.util.invalidateTags([{ type: 'Board', id: image.board_id }]) + ); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts new file mode 100644 index 0000000000..40847ade3a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts @@ -0,0 +1,40 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addImageRemovedFromBoardFulfilledListener = () => { + startAppListening({ + matcher: api.endpoints.removeImageFromBoard.matchFulfilled, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Image added to board' + ); + + dispatch( + imageMetadataReceived({ + imageName: image_name, + }) + ); + }, + }); +}; + +export const addImageRemovedFromBoardRejectedListener = () => { + startAppListening({ + matcher: api.endpoints.removeImageFromBoard.matchRejected, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Problem adding image to board' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 40ed062353..fc44d206c8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => { if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') { const { controlNetId } = postUploadAction; - dispatch(controlNetImageChanged({ controlNetId, controlImage: image })); + dispatch( + controlNetImageChanged({ + controlNetId, + controlImage: image.image_name, + }) + ); return; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index d65f8e8ba6..bf54e63836 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,7 +1,6 @@ import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; import { startAppListening } from '../..'; @@ -15,21 +14,17 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { sd1pipelinemodels, sd2pipelinemodels, nodes, config, images } = - getState(); + const { nodes, config, images } = getState(); const { disabledTabs } = config; if (!images.ids.length) { - dispatch(receivedPageOfImages()); - } - - if (!sd1pipelinemodels.ids.length) { - dispatch(receivedModels({ baseModel: 'sd-1', modelType: 'pipeline' })); - } - - if (!sd2pipelinemodels.ids.length) { - dispatch(receivedModels({ baseModel: 'sd-2', modelType: 'pipeline' })); + dispatch( + receivedPageOfImages({ + categories: ['general'], + isIntermediate: false, + }) + ); } if (!nodes.schema && !disabledTabs.includes('nodes')) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index c9ab894ddb..c204f0bdfb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image'; import { sessionCanceled } from 'services/thunks/session'; import { isImageOutput } from 'services/types/guards'; import { progressImageSet } from 'features/system/store/systemSlice'; +import { api } from 'services/apiSlice'; const moduleLog = log.child({ namespace: 'socketio' }); const nodeDenylist = ['dataURL_image']; @@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => { const sessionId = action.payload.data.graph_execution_state_id; - const { cancelType, isCancelScheduled } = getState().system; + const { cancelType, isCancelScheduled, boardIdToAddTo } = + getState().system; // Handle scheduled cancelation if (cancelType === 'scheduled' && isCancelScheduled) { @@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => { dispatch(addImageToStagingArea(imageDTO)); } + if (boardIdToAddTo && !imageDTO.is_intermediate) { + dispatch( + api.endpoints.addImageToBoard.initiate({ + board_id: boardIdToAddTo, + image_name, + }) + ); + } + dispatch(progressImageSet(null)); } // pass along the socket event as an application action diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts index 7cb8012848..b9ddcea4c3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts @@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector( selectImagesEntities, ], (generation, canvas, nodes, controlNet, imageEntities) => { - const allUsedImages: ImageDTO[] = []; + const allUsedImages: string[] = []; if (generation.initialImage) { - allUsedImages.push(generation.initialImage); + allUsedImages.push(generation.initialImage.imageName); } canvas.layerState.objects.forEach((obj) => { if (obj.kind === 'image') { - allUsedImages.push(obj.image); + allUsedImages.push(obj.imageName); } }); @@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector( forEach(imageEntities, (image) => { if (image) { - allUsedImages.push(image); + allUsedImages.push(image.image_name); } }); @@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => { `Fetching new image URLs for ${allUsedImages.length} images` ); - allUsedImages.forEach(({ image_name }) => { + allUsedImages.forEach((image_name) => { dispatch( imageUrlsReceived({ imageName: image_name, diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 06aa6d3535..57a97168a3 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -18,6 +18,7 @@ import postprocessingReducer from 'features/parameters/store/postprocessingSlice import systemReducer from 'features/system/store/systemSlice'; // import sessionReducer from 'features/system/store/sessionSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; @@ -27,22 +28,16 @@ import { listenerMiddleware } from './middleware/listenerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; - -// Model Reducers -import sd1PipelineModelReducer from 'features/system/store/models/sd1PipelineModelSlice'; -import sd2PipelineModelReducer from 'features/system/store/models/sd2PipelineModelSlice'; - import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; +import { api } from 'services/apiSlice'; const allReducers = { canvas: canvasReducer, gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - sd1pipelinemodels: sd1PipelineModelReducer, - sd2pipelinemodels: sd2PipelineModelReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, @@ -51,7 +46,9 @@ const allReducers = { hotkeys: hotkeysReducer, images: imagesReducer, controlNet: controlNetReducer, + boards: boardsReducer, // session: sessionReducer, + [api.reducerPath]: api.reducer, }; const rootReducer = combineReducers(allReducers); @@ -68,6 +65,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'system', 'ui', 'controlNet', + // 'boards', // 'hotkeys', // 'config', ]; @@ -87,6 +85,7 @@ export const store = configureStore({ immutableCheck: false, serializableCheck: false, }) + .concat(api.middleware) .concat(dynamicMiddlewares) .prepend(listenerMiddleware.middleware), devTools: { diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 669a68c88a..e54b4a8872 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -9,7 +9,7 @@ import { import { useDraggable, useDroppable } from '@dnd-kit/core'; import { useCombinedRefs } from '@dnd-kit/utilities'; import IAIIconButton from 'common/components/IAIIconButton'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { AnimatePresence } from 'framer-motion'; import { ReactElement, SyntheticEvent, useCallback } from 'react'; @@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => { isDropDisabled = false, isDragDisabled = false, isUploadDisabled = false, - fallback = , + fallback = , payloadImage, minSize = 24, postUploadAction, diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index 3d34fbca9e..03a00d5b1c 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -1,10 +1,20 @@ -import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react'; +import { + As, + Flex, + FlexProps, + Icon, + IconProps, + Spinner, + SpinnerProps, +} from '@chakra-ui/react'; +import { ReactElement } from 'react'; +import { FaImage } from 'react-icons/fa'; type Props = FlexProps & { spinnerProps?: SpinnerProps; }; -export const IAIImageFallback = (props: Props) => { +export const IAIImageLoadingFallback = (props: Props) => { const { spinnerProps, ...rest } = props; const { sx, ...restFlexProps } = rest; return ( @@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => { ); }; + +type IAINoImageFallbackProps = { + flexProps?: FlexProps; + iconProps?: IconProps; + as?: As; +}; + +export const IAINoImageFallback = (props: IAINoImageFallbackProps) => { + const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} }; + const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} }; + return ( + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx index b8757eff0c..c3132f0285 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx @@ -1,14 +1,21 @@ -import { Image } from 'react-konva'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { Image, Rect } from 'react-konva'; +import { useGetImageDTOQuery } from 'services/apiSlice'; import useImage from 'use-image'; +import { CanvasImage } from '../store/canvasTypes'; type IAICanvasImageProps = { - url: string; - x: number; - y: number; + canvasImage: CanvasImage; }; const IAICanvasImage = (props: IAICanvasImageProps) => { - const { url, x, y } = props; - const [image] = useImage(url, 'anonymous'); + const { width, height, x, y, imageName } = props.canvasImage; + const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); + const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous'); + + if (!imageDTO) { + return ; + } + return ; }; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx index ea04aa95c8..ec1e87cca7 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx @@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => { {objects.map((obj, i) => { if (isCanvasBaseImage(obj)) { - return ( - - ); + return ; } else if (isCanvasBaseLine(obj)) { const line = ( { return ( {shouldShowStagingImage && currentStagingAreaImage && ( - + )} {shouldShowStagingOutline && ( diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index b7092bf7e0..3e40c1211d 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -203,7 +203,7 @@ export const canvasSlice = createSlice({ y: 0, width: width, height: height, - image: image, + imageName: image.image_name, }, ], }; @@ -325,7 +325,7 @@ export const canvasSlice = createSlice({ kind: 'image', layer: 'base', ...state.layerState.stagingArea.boundingBox, - image, + imageName: image.image_name, }); state.layerState.stagingArea.selectedImageIndex = @@ -865,25 +865,25 @@ export const canvasSlice = createSlice({ state.doesCanvasNeedScaling = true; }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - state.layerState.objects.forEach((object) => { - if (object.kind === 'image') { - if (object.image.image_name === image_name) { - object.image.image_url = image_url; - object.image.thumbnail_url = thumbnail_url; - } - } - }); + // state.layerState.objects.forEach((object) => { + // if (object.kind === 'image') { + // if (object.image.image_name === image_name) { + // object.image.image_url = image_url; + // object.image.thumbnail_url = thumbnail_url; + // } + // } + // }); - state.layerState.stagingArea.images.forEach((stagedImage) => { - if (stagedImage.image.image_name === image_name) { - stagedImage.image.image_url = image_url; - stagedImage.image.thumbnail_url = thumbnail_url; - } - }); - }); + // state.layerState.stagingArea.images.forEach((stagedImage) => { + // if (stagedImage.image.image_name === image_name) { + // stagedImage.image.image_url = image_url; + // stagedImage.image.thumbnail_url = thumbnail_url; + // } + // }); + // }); }, }); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index ae78287a7b..9294e10d32 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -38,7 +38,7 @@ export type CanvasImage = { y: number; width: number; height: number; - image: ImageDTO; + imageName: string; }; export type CanvasMaskLine = { diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index b8d8896dad..217caf9461 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage'; import { createSelector } from '@reduxjs/toolkit'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { AnimatePresence, motion } from 'framer-motion'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import IAIIconButton from 'common/components/IAIIconButton'; import { FaUndo } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const selector = createSelector( controlNetSelector, @@ -31,24 +33,45 @@ type Props = { const ControlNetImagePreview = (props: Props) => { const { imageSx } = props; - const { controlNetId, controlImage, processedControlImage, processorType } = - props.controlNet; + const { + controlNetId, + controlImage: controlImageName, + processedControlImage: processedControlImageName, + processorType, + } = props.controlNet; const dispatch = useAppDispatch(); const { pendingControlImages } = useAppSelector(selector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); + const { + data: controlImage, + isLoading: isLoadingControlImage, + isError: isErrorControlImage, + isSuccess: isSuccessControlImage, + } = useGetImageDTOQuery(controlImageName ?? skipToken); + + const { + data: processedControlImage, + isLoading: isLoadingProcessedControlImage, + isError: isErrorProcessedControlImage, + isSuccess: isSuccessProcessedControlImage, + } = useGetImageDTOQuery(processedControlImageName ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (controlImage?.image_name === droppedImage.image_name) { + if (controlImageName === droppedImage.image_name) { return; } setIsMouseOverImage(false); dispatch( - controlNetImageChanged({ controlNetId, controlImage: droppedImage }) + controlNetImageChanged({ + controlNetId, + controlImage: droppedImage.image_name, + }) ); }, - [controlImage, controlNetId, dispatch] + [controlImageName, controlNetId, dispatch] ); const handleResetControlImage = useCallback(() => { @@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => { h: 'full', }} > - + )} {controlImage && ( diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index f1b62cd997..5a54bdcd74 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -39,8 +39,8 @@ export type ControlNetConfig = { weight: number; beginStepPct: number; endStepPct: number; - controlImage: ImageDTO | null; - processedControlImage: ImageDTO | null; + controlImage: string | null; + processedControlImage: string | null; processorType: ControlNetProcessorType; processorNode: RequiredControlNetProcessorNode; shouldAutoConfig: boolean; @@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({ }, controlNetAddedFromImage: ( state, - action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }> + action: PayloadAction<{ controlNetId: string; controlImage: string }> ) => { const { controlNetId, controlImage } = action.payload; state.controlNets[controlNetId] = { @@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - controlImage: ImageDTO | null; + controlImage: string | null; }> ) => { const { controlNetId, controlImage } = action.payload; @@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - processedControlImage: ImageDTO | null; + processedControlImage: string | null; }> ) => { const { controlNetId, processedControlImage } = action.payload; @@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({ // Preemptively remove the image from the gallery const { imageName } = action.meta.arg; forEach(state.controlNets, (c) => { - if (c.controlImage?.image_name === imageName) { + if (c.controlImage === imageName) { c.controlImage = null; c.processedControlImage = null; } - if (c.processedControlImage?.image_name === imageName) { + if (c.processedControlImage === imageName) { c.processedControlImage = null; } }); }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - forEach(state.controlNets, (c) => { - if (c.controlImage?.image_name === image_name) { - c.controlImage.image_url = image_url; - c.controlImage.thumbnail_url = thumbnail_url; - } - if (c.processedControlImage?.image_name === image_name) { - c.processedControlImage.image_url = image_url; - c.processedControlImage.thumbnail_url = thumbnail_url; - } - }); - }); + // forEach(state.controlNets, (c) => { + // if (c.controlImage?.image_name === image_name) { + // c.controlImage.image_url = image_url; + // c.controlImage.thumbnail_url = thumbnail_url; + // } + // if (c.processedControlImage?.image_name === image_name) { + // c.processedControlImage.image_url = image_url; + // c.processedControlImage.thumbnail_url = thumbnail_url; + // } + // }); + // }); builder.addCase(appSocketInvocationError, (state, action) => { state.pendingControlImages = []; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx new file mode 100644 index 0000000000..632cebcb33 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx @@ -0,0 +1,27 @@ +import IAIButton from 'common/components/IAIButton'; +import { useCallback } from 'react'; +import { useCreateBoardMutation } from 'services/apiSlice'; + +const DEFAULT_BOARD_NAME = 'My Board'; + +const AddBoardButton = () => { + const [createBoard, { isLoading }] = useCreateBoardMutation(); + + const handleCreateBoard = useCallback(() => { + createBoard(DEFAULT_BOARD_NAME); + }, [createBoard]); + + return ( + + Add Board + + ); +}; + +export default AddBoardButton; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx new file mode 100644 index 0000000000..e506c88e2d --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx @@ -0,0 +1,93 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { FaImages } from 'react-icons/fa'; +import { boardIdSelected } from '../../store/boardSlice'; +import { useDispatch } from 'react-redux'; +import { IAINoImageFallback } from 'common/components/IAIImageFallback'; +import { AnimatePresence } from 'framer-motion'; +import { SelectedItemOverlay } from '../SelectedItemOverlay'; +import { useCallback } from 'react'; +import { ImageDTO } from 'services/api'; +import { useRemoveImageFromBoardMutation } from 'services/apiSlice'; +import { useDroppable } from '@dnd-kit/core'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; + +const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { + const dispatch = useDispatch(); + + const handleAllImagesBoardClick = () => { + dispatch(boardIdSelected()); + }; + + const [removeImageFromBoard, { isLoading }] = + useRemoveImageFromBoardMutation(); + + const handleDrop = useCallback( + (droppedImage: ImageDTO) => { + if (!droppedImage.board_id) { + return; + } + removeImageFromBoard({ + board_id: droppedImage.board_id, + image_name: droppedImage.image_name, + }); + }, + [removeImageFromBoard] + ); + + const { + isOver, + setNodeRef, + active: isDropActive, + } = useDroppable({ + id: `board_droppable_all_images`, + data: { + handleDrop, + }, + }); + + return ( + + + + + {isSelected && } + + + {isDropActive && } + + + + All Images + + + ); +}; + +export default AllImagesBoard; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx new file mode 100644 index 0000000000..738693a278 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx @@ -0,0 +1,134 @@ +import { + Collapse, + Flex, + Grid, + IconButton, + Input, + InputGroup, + InputRightElement, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { + boardsSelector, + setBoardSearchText, +} from 'features/gallery/store/boardSlice'; +import { memo, useState } from 'react'; +import HoverableBoard from './HoverableBoard'; +import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; +import AddBoardButton from './AddBoardButton'; +import AllImagesBoard from './AllImagesBoard'; +import { CloseIcon } from '@chakra-ui/icons'; +import { useListAllBoardsQuery } from 'services/apiSlice'; + +const selector = createSelector( + [boardsSelector], + (boardsState) => { + const { selectedBoardId, searchText } = boardsState; + return { selectedBoardId, searchText }; + }, + defaultSelectorOptions +); + +type Props = { + isOpen: boolean; +}; + +const BoardsList = (props: Props) => { + const { isOpen } = props; + const dispatch = useAppDispatch(); + const { selectedBoardId, searchText } = useAppSelector(selector); + + const { data: boards } = useListAllBoardsQuery(); + + const filteredBoards = searchText + ? boards?.filter((board) => + board.board_name.toLowerCase().includes(searchText.toLowerCase()) + ) + : boards; + + const [searchMode, setSearchMode] = useState(false); + + const handleBoardSearch = (searchTerm: string) => { + setSearchMode(searchTerm.length > 0); + dispatch(setBoardSearchText(searchTerm)); + }; + const clearBoardSearch = () => { + setSearchMode(false); + dispatch(setBoardSearchText('')); + }; + + return ( + + + + + { + handleBoardSearch(e.target.value); + }} + /> + {searchText && searchText.length && ( + + } + /> + + )} + + + + + + {!searchMode && } + {filteredBoards && + filteredBoards.map((board) => ( + + ))} + + + + + ); +}; + +export default memo(BoardsList); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx new file mode 100644 index 0000000000..a2c07e4870 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx @@ -0,0 +1,193 @@ +import { + Badge, + Box, + Editable, + EditableInput, + EditablePreview, + Flex, + Image, + MenuItem, + MenuList, +} from '@chakra-ui/react'; + +import { useAppDispatch } from 'app/store/storeHooks'; +import { memo, useCallback } from 'react'; +import { FaFolder, FaTrash } from 'react-icons/fa'; +import { ContextMenu } from 'chakra-ui-contextmenu'; +import { BoardDTO, ImageDTO } from 'services/api'; +import { IAINoImageFallback } from 'common/components/IAIImageFallback'; +import { boardIdSelected } from 'features/gallery/store/boardSlice'; +import { + useAddImageToBoardMutation, + useDeleteBoardMutation, + useGetImageDTOQuery, + useUpdateBoardMutation, +} from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { useDroppable } from '@dnd-kit/core'; +import { AnimatePresence } from 'framer-motion'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; +import { SelectedItemOverlay } from '../SelectedItemOverlay'; + +interface HoverableBoardProps { + board: BoardDTO; + isSelected: boolean; +} + +const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { + const dispatch = useAppDispatch(); + + const { data: coverImage } = useGetImageDTOQuery( + board.cover_image_name ?? skipToken + ); + + const { board_name, board_id } = board; + + const handleSelectBoard = useCallback(() => { + dispatch(boardIdSelected(board_id)); + }, [board_id, dispatch]); + + const [updateBoard, { isLoading: isUpdateBoardLoading }] = + useUpdateBoardMutation(); + + const [deleteBoard, { isLoading: isDeleteBoardLoading }] = + useDeleteBoardMutation(); + + const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] = + useAddImageToBoardMutation(); + + const handleUpdateBoardName = (newBoardName: string) => { + updateBoard({ board_id, changes: { board_name: newBoardName } }); + }; + + const handleDeleteBoard = useCallback(() => { + deleteBoard(board_id); + }, [board_id, deleteBoard]); + + const handleDrop = useCallback( + (droppedImage: ImageDTO) => { + if (droppedImage.board_id === board_id) { + return; + } + addImageToBoard({ board_id, image_name: droppedImage.image_name }); + }, + [addImageToBoard, board_id] + ); + + const { + isOver, + setNodeRef, + active: isDropActive, + } = useDroppable({ + id: `board_droppable_${board_id}`, + data: { + handleDrop, + }, + }); + + return ( + + + menuProps={{ size: 'sm', isLazy: true }} + renderMenu={() => ( + + } + onClickCapture={handleDeleteBoard} + > + Delete Board + + + )} + > + {(ref) => ( + + + {board.cover_image_name && coverImage?.image_url && ( + + )} + {!(board.cover_image_name && coverImage?.image_url) && ( + + )} + + {board.image_count} + + + {isSelected && } + + + {isDropActive && } + + + + + { + handleUpdateBoardName(nextValue); + }} + > + + + + + + )} + + + ); +}); + +HoverableBoard.displayName = 'HoverableBoard'; + +export default HoverableBoard; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx new file mode 100644 index 0000000000..b16bddd6b4 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx @@ -0,0 +1,93 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogOverlay, + Box, + Flex, + Spinner, + Text, +} from '@chakra-ui/react'; +import IAIButton from 'common/components/IAIButton'; + +import { memo, useContext, useRef, useState } from 'react'; +import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useListAllBoardsQuery } from 'services/apiSlice'; + +const UpdateImageBoardModal = () => { + // const boards = useSelector(selectBoardsAll); + const { data: boards, isFetching } = useListAllBoardsQuery(); + const { isOpen, onClose, handleAddToBoard, image } = useContext( + AddImageToBoardContext + ); + const [selectedBoard, setSelectedBoard] = useState(); + + const cancelRef = useRef(null); + + const currentBoard = boards?.find( + (board) => board.board_id === image?.board_id + ); + + return ( + + + + + {currentBoard ? 'Move Image to Board' : 'Add Image to Board'} + + + + + + {currentBoard && ( + + Moving this image from{' '} + {currentBoard.board_name} to + + )} + {isFetching ? ( + + ) : ( + setSelectedBoard(v)} + value={selectedBoard} + data={(boards ?? []).map((board) => ({ + label: board.board_name, + value: board.board_id, + }))} + /> + )} + + + + + Cancel + { + if (selectedBoard) { + handleAddToBoard(selectedBoard); + } + }} + ml={3} + > + {currentBoard ? 'Move' : 'Add'} + + + + + + ); +}; + +export default memo(UpdateImageBoardModal); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index a5eaeb4c71..169a965be0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { DeleteImageButton } from './DeleteImageModal'; +import { selectImagesById } from '../store/imagesSlice'; +import { RootState } from 'app/store/store'; const currentImageButtonsSelector = createSelector( [ + (state: RootState) => state, systemSelector, gallerySelector, postprocessingSelector, @@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector( lightboxSelector, activeTabNameSelector, ], - (system, gallery, postprocessing, ui, lightbox, activeTabName) => { + (state, system, gallery, postprocessing, ui, lightbox, activeTabName) => { const { isProcessing, isConnected, @@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector( shouldShowProgressInViewer, } = ui; + const imageDTO = selectImagesById(state, gallery.selectedImage ?? ''); + const { selectedImage } = gallery; return { @@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector( activeTabName, isLightboxOpen, shouldHidePreview, - image: selectedImage, - seed: selectedImage?.metadata?.seed, - prompt: selectedImage?.metadata?.positive_conditioning, - negativePrompt: selectedImage?.metadata?.negative_conditioning, + image: imageDTO, + seed: imageDTO?.metadata?.seed, + prompt: imageDTO?.metadata?.positive_conditioning, + negativePrompt: imageDTO?.metadata?.negative_conditioning, shouldShowProgressInViewer, }; }, diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index c591206a27..5426fee3b1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import NextPrevImageButtons from './NextPrevImageButtons'; import { memo, useCallback } from 'react'; import { systemSelector } from 'features/system/store/systemSelectors'; -import { configSelector } from '../../system/store/configSelectors'; -import { useAppToaster } from 'app/components/Toaster'; import { imageSelected } from '../store/gallerySlice'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; export const imagesSelector = createSelector( [uiSelector, gallerySelector, systemSelector], @@ -29,7 +29,7 @@ export const imagesSelector = createSelector( return { shouldShowImageDetails, shouldHidePreview, - image: selectedImage, + selectedImage, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, @@ -45,11 +45,23 @@ export const imagesSelector = createSelector( const CurrentImagePreview = () => { const { shouldShowImageDetails, - image, + selectedImage, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, } = useAppSelector(imagesSelector); + + // const image = useAppSelector((state: RootState) => + // selectImagesById(state, selectedImage ?? '') + // ); + + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(selectedImage ?? skipToken); + const dispatch = useAppDispatch(); const handleDrop = useCallback( @@ -57,7 +69,7 @@ const CurrentImagePreview = () => { if (droppedImage.image_name === image?.image_name) { return; } - dispatch(imageSelected(droppedImage)); + dispatch(imageSelected(droppedImage.image_name)); }, [dispatch, image?.image_name] ); @@ -98,14 +110,14 @@ const CurrentImagePreview = () => { }} > } + fallback={} isUploadDisabled={true} /> )} - {shouldShowImageDetails && image && ( + {shouldShowImageDetails && image && selectedImage && ( { )} - {!shouldShowImageDetails && image && ( + {!shouldShowImageDetails && image && selectedImage && ( - prev.image.image_name === next.image.image_name && - prev.isSelected === next.isSelected; - /** * Gallery image component with delete/use all/use seed buttons on hover. */ -const HoverableImage = memo((props: HoverableImageProps) => { +const HoverableImage = (props: HoverableImageProps) => { const dispatch = useAppDispatch(); const { activeTabName, @@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const { onDelete } = useContext(DeleteImageContext); + const { onClickAddToBoard } = useContext(AddImageToBoardContext); const handleDelete = useCallback(() => { onDelete(image); }, [image, onDelete]); @@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => { }, }); + const [removeFromBoard] = useRemoveImageFromBoardMutation(); + const handleMouseOver = () => setIsHovered(true); const handleMouseOut = () => setIsHovered(false); const handleSelectImage = useCallback(() => { - dispatch(imageSelected(image)); + dispatch(imageSelected(image.image_name)); }, [image, dispatch]); // Recall parameters handlers @@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => { // dispatch(setIsLightboxOpen(true)); }; + const handleAddToBoard = useCallback(() => { + onClickAddToBoard(image); + }, [image, onClickAddToBoard]); + + const handleRemoveFromBoard = useCallback(() => { + if (!image.board_id) { + return; + } + removeFromBoard({ board_id: image.board_id, image_name: image.image_name }); + }, [image.board_id, image.image_name, removeFromBoard]); + const handleOpenInNewTab = () => { window.open(image.image_url, '_blank'); }; @@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => { {t('parameters.sendToUnifiedCanvas')} )} + } onClickCapture={handleAddToBoard}> + {image.board_id ? 'Change Board' : 'Add to Board'} + + {image.board_id && ( + } + onClickCapture={handleRemoveFromBoard} + > + Remove from Board + + )} } @@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => { ); -}, memoEqualityCheck); +}; -HoverableImage.displayName = 'HoverableImage'; - -export default HoverableImage; +export default memo(HoverableImage); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index fe8690e379..46f2378ae0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -1,12 +1,15 @@ import { Box, + Button, ButtonGroup, Flex, FlexProps, Grid, Icon, Text, + VStack, forwardRef, + useDisclosure, } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; @@ -20,6 +23,7 @@ import { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, + setGalleryView, } from 'features/gallery/store/gallerySlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; @@ -53,41 +57,51 @@ import { selectImagesAll, } from '../store/imagesSlice'; import { receivedPageOfImages } from 'services/thunks/image'; +import BoardsList from './Boards/BoardsList'; +import { boardsSelector } from '../store/boardSlice'; +import { ChevronUpIcon } from '@chakra-ui/icons'; +import { useListAllBoardsQuery } from 'services/apiSlice'; -const categorySelector = createSelector( +const itemSelector = createSelector( [(state: RootState) => state], (state) => { - const { images } = state; - const { categories } = images; + const { categories, total: allImagesTotal, isLoading } = state.images; + const { selectedBoardId } = state.boards; const allImages = selectImagesAll(state); - const filteredImages = allImages.filter((i) => - categories.includes(i.image_category) - ); + + const images = allImages.filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = selectedBoardId + ? i.board_id === selectedBoardId + : true; + return isInCategory && isInSelectedBoard; + }); return { - images: filteredImages, - isLoading: images.isLoading, - areMoreImagesAvailable: filteredImages.length < images.total, - categories: images.categories, + images, + allImagesTotal, + isLoading, + categories, + selectedBoardId, }; }, defaultSelectorOptions ); const mainSelector = createSelector( - [gallerySelector, uiSelector], - (gallery, ui) => { + [gallerySelector, uiSelector, boardsSelector], + (gallery, ui, boards) => { const { galleryImageMinimumWidth, galleryImageObjectFit, shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, } = gallery; const { shouldPinGallery } = ui; - return { shouldPinGallery, galleryImageMinimumWidth, @@ -95,6 +109,8 @@ const mainSelector = createSelector( shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, + selectedBoardId: boards.selectedBoardId, }; }, defaultSelectorOptions @@ -126,21 +142,44 @@ const ImageGalleryContent = () => { shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, } = useAppSelector(mainSelector); - const { images, areMoreImagesAvailable, isLoading, categories } = - useAppSelector(categorySelector); + const { images, isLoading, allImagesTotal, categories, selectedBoardId } = + useAppSelector(itemSelector); + + const { selectedBoard } = useListAllBoardsQuery(undefined, { + selectFromResult: ({ data }) => ({ + selectedBoard: data?.find((b) => b.board_id === selectedBoardId), + }), + }); + + const filteredImagesTotal = useMemo( + () => selectedBoard?.image_count ?? allImagesTotal, + [allImagesTotal, selectedBoard?.image_count] + ); + + const areMoreAvailable = useMemo(() => { + return images.length < filteredImagesTotal; + }, [images.length, filteredImagesTotal]); const handleLoadMoreImages = useCallback(() => { - dispatch(receivedPageOfImages()); - }, [dispatch]); + dispatch( + receivedPageOfImages({ + categories, + boardId: selectedBoardId, + }) + ); + }, [categories, dispatch, selectedBoardId]); const handleEndReached = useMemo(() => { - if (areMoreImagesAvailable && !isLoading) { + if (areMoreAvailable && !isLoading) { return handleLoadMoreImages; } return undefined; - }, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]); + }, [areMoreAvailable, handleLoadMoreImages, isLoading]); + + const { isOpen: isBoardListOpen, onToggle } = useDisclosure(); const handleChangeGalleryImageMinimumWidth = (v: number) => { dispatch(setGalleryImageMinimumWidth(v)); @@ -172,46 +211,79 @@ const ImageGalleryContent = () => { const handleClickImagesCategory = useCallback(() => { dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); + dispatch(setGalleryView('images')); }, [dispatch]); const handleClickAssetsCategory = useCallback(() => { dispatch(imageCategoriesChanged(ASSETS_CATEGORIES)); + dispatch(setGalleryView('assets')); }, [dispatch]); return ( - - - - + + + } + /> + } + /> + + } - /> - } - /> - - + variant="ghost" + sx={{ + w: 'full', + justifyContent: 'center', + alignItems: 'center', + px: 2, + _hover: { + bg: 'base.800', + }, + }} + > + + {selectedBoard ? selectedBoard.board_name : 'All Images'} + + + { icon={shouldPinGallery ? : } /> - - - {images.length || areMoreImagesAvailable ? ( + + + + + + {images.length || areMoreAvailable ? ( <> {shouldUseSingleGalleryColumn ? ( @@ -280,14 +355,12 @@ const ImageGalleryContent = () => { data={images} endReached={handleEndReached} scrollerRef={(ref) => setScrollerRef(ref)} - itemContent={(index, image) => ( + itemContent={(index, item) => ( )} @@ -302,13 +375,11 @@ const ImageGalleryContent = () => { List: ListContainer, }} scrollerRef={setScroller} - itemContent={(index, image) => ( + itemContent={(index, item) => ( )} /> @@ -316,12 +387,12 @@ const ImageGalleryContent = () => { - {areMoreImagesAvailable + {areMoreAvailable ? t('gallery.loadMore') : t('gallery.allImagesLoaded')} @@ -350,7 +421,7 @@ const ImageGalleryContent = () => { )} - + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx index 892516a3cc..e5cb4cf4a8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx @@ -93,19 +93,11 @@ type ImageMetadataViewerProps = { image: ImageDTO; }; -// TODO: I don't know if this is needed. -const memoEqualityCheck = ( - prev: ImageMetadataViewerProps, - next: ImageMetadataViewerProps -) => prev.image.image_name === next.image.image_name; - -// TODO: Show more interesting information in this component. - /** * Image metadata viewer overlays currently selected image and provides * access to any of its metadata for use in processing. */ -const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { +const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const dispatch = useAppDispatch(); const { recallBothPrompts, @@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { ); -}, memoEqualityCheck); +}; -ImageMetadataViewer.displayName = 'ImageMetadataViewer'; - -export default ImageMetadataViewer; +export default memo(ImageMetadataViewer); diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index 82e7a0d623..b1f06ad433 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector( } const currentImageIndex = filteredImageIds.findIndex( - (i) => i === selectedImage.image_name + (i) => i === selectedImage ); const nextImageIndex = clamp( @@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector( !isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1, nextImage, prevImage, + nextImageId, + prevImageId, }; }, { @@ -84,7 +86,7 @@ const NextPrevImageButtons = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { isOnFirstImage, isOnLastImage, nextImage, prevImage } = + const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } = useAppSelector(nextPrevImageButtonsSelector); const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = @@ -99,19 +101,19 @@ const NextPrevImageButtons = () => { }, []); const handlePrevImage = useCallback(() => { - dispatch(imageSelected(prevImage)); - }, [dispatch, prevImage]); + dispatch(imageSelected(prevImageId)); + }, [dispatch, prevImageId]); const handleNextImage = useCallback(() => { - dispatch(imageSelected(nextImage)); - }, [dispatch, nextImage]); + dispatch(imageSelected(nextImageId)); + }, [dispatch, nextImageId]); useHotkeys( 'left', () => { handlePrevImage(); }, - [prevImage] + [prevImageId] ); useHotkeys( @@ -119,7 +121,7 @@ const NextPrevImageButtons = () => { () => { handleNextImage(); }, - [nextImage] + [nextImageId] ); return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx new file mode 100644 index 0000000000..7038b4b64f --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx @@ -0,0 +1,26 @@ +import { motion } from 'framer-motion'; + +export const SelectedItemOverlay = () => ( + +); diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts new file mode 100644 index 0000000000..3dac2b6e50 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts @@ -0,0 +1,23 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { selectBoardsAll } from './boardSlice'; + +export const boardSelector = (state: RootState) => state.boards.entities; + +export const searchBoardsSelector = createSelector( + (state: RootState) => state, + (state) => { + const { + boards: { searchText }, + } = state; + + if (!searchText) { + // If no search text provided, return all entities + return selectBoardsAll(state); + } + + return selectBoardsAll(state).filter((i) => + i.board_name.toLowerCase().includes(searchText.toLowerCase()) + ); + } +); diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts new file mode 100644 index 0000000000..8fc9bfa486 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts @@ -0,0 +1,47 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { api } from 'services/apiSlice'; + +type BoardsState = { + searchText: string; + selectedBoardId?: string; + updateBoardModalOpen: boolean; +}; + +export const initialBoardsState: BoardsState = { + updateBoardModalOpen: false, + searchText: '', +}; + +const boardsSlice = createSlice({ + name: 'boards', + initialState: initialBoardsState, + reducers: { + boardIdSelected: (state, action: PayloadAction) => { + state.selectedBoardId = action.payload; + }, + setBoardSearchText: (state, action: PayloadAction) => { + state.searchText = action.payload; + }, + setUpdateBoardModalOpen: (state, action: PayloadAction) => { + state.updateBoardModalOpen = action.payload; + }, + }, + extraReducers: (builder) => { + builder.addMatcher( + api.endpoints.deleteBoard.matchFulfilled, + (state, action) => { + if (action.meta.arg.originalArgs === state.selectedBoardId) { + state.selectedBoardId = undefined; + } + } + ); + }, +}); + +export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } = + boardsSlice.actions; + +export const boardsSelector = (state: RootState) => state.boards; + +export default boardsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 4f250a7c3a..b7fc0809a6 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,17 +1,16 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { ImageDTO } from 'services/api'; import { imageUpserted } from './imagesSlice'; -import { imageUrlsReceived } from 'services/thunks/image'; type GalleryImageObjectFitType = 'contain' | 'cover'; export interface GalleryState { - selectedImage?: ImageDTO; + selectedImage?: string; galleryImageMinimumWidth: number; galleryImageObjectFit: GalleryImageObjectFitType; shouldAutoSwitchToNewImages: boolean; shouldUseSingleGalleryColumn: boolean; + galleryView: 'images' | 'assets' | 'boards'; } export const initialGalleryState: GalleryState = { @@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = { galleryImageObjectFit: 'cover', shouldAutoSwitchToNewImages: true, shouldUseSingleGalleryColumn: false, + galleryView: 'images', }; export const gallerySlice = createSlice({ name: 'gallery', initialState: initialGalleryState, reducers: { - imageSelected: (state, action: PayloadAction) => { + imageSelected: (state, action: PayloadAction) => { state.selectedImage = action.payload; // TODO: if the user selects an image, disable the auto switch? // state.shouldAutoSwitchToNewImages = false; @@ -48,6 +48,12 @@ export const gallerySlice = createSlice({ ) => { state.shouldUseSingleGalleryColumn = action.payload; }, + setGalleryView: ( + state, + action: PayloadAction<'images' | 'assets' | 'boards'> + ) => { + state.galleryView = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(imageUpserted, (state, action) => { @@ -55,17 +61,17 @@ export const gallerySlice = createSlice({ state.shouldAutoSwitchToNewImages && action.payload.image_category === 'general' ) { - state.selectedImage = action.payload; + state.selectedImage = action.payload.image_name; } }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - if (state.selectedImage?.image_name === image_name) { - state.selectedImage.image_url = image_url; - state.selectedImage.thumbnail_url = thumbnail_url; - } - }); + // if (state.selectedImage?.image_name === image_name) { + // state.selectedImage.image_url = image_url; + // state.selectedImage.thumbnail_url = thumbnail_url; + // } + // }); }, }); @@ -75,6 +81,7 @@ export const { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, + setGalleryView, } = gallerySlice.actions; export default gallerySlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts index 9c18380c54..25a3341532 100644 --- a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts @@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator'; import { keyBy } from 'lodash-es'; import { imageDeleted, - imageMetadataReceived, imageUrlsReceived, receivedPageOfImages, } from 'services/thunks/image'; @@ -74,11 +73,21 @@ const imagesSlice = createSlice({ }); builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { state.isLoading = false; + const { boardId, categories, imageOrigin, isIntermediate } = + action.meta.arg; + const { items, offset, limit, total } = action.payload; + imagesAdapter.upsertMany(state, items); + + if (!categories?.includes('general') || boardId) { + // need to skip updating the total images count if the images recieved were for a specific board + // TODO: this doesn't work when on the Asset tab/category... + return; + } + state.offset = offset; state.limit = limit; state.total = total; - imagesAdapter.upsertMany(state, items); }); builder.addCase(imageDeleted.pending, (state, action) => { // Image deleted @@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector( .map((i) => i.image_name); } ); + +// export const selectImageById = createSelector( +// (state: RootState, imageId) => state, +// (state) => { +// const { +// images: { categories }, +// } = state; + +// return selectImagesAll(state) +// .filter((i) => categories.includes(i.image_category)) +// .map((i) => i.image_name); +// } +// ); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index dc4590e6ca..c5a3a1970b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -11,6 +11,8 @@ import { FieldComponentProps } from './types'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; import { Flex } from '@chakra-ui/react'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const ImageInputFieldComponent = ( props: FieldComponentProps @@ -19,9 +21,16 @@ const ImageInputFieldComponent = ( const dispatch = useAppDispatch(); + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(field.value ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (field.value?.image_name === droppedImage.image_name) { + if (field.value === droppedImage.image_name) { return; } @@ -29,11 +38,11 @@ const ImageInputFieldComponent = ( fieldValueChanged({ nodeId, fieldName: field.name, - value: droppedImage, + value: droppedImage.image_name, }) ); }, - [dispatch, field.name, field.value?.image_name, nodeId] + [dispatch, field.name, field.value, nodeId] ); const handleReset = useCallback(() => { @@ -56,7 +65,7 @@ const ImageInputFieldComponent = ( }} > @@ -16,26 +20,82 @@ const ModelInputFieldComponent = ( const { nodeId, field } = props; const dispatch = useAppDispatch(); + const { t } = useTranslation(); - const { sd1PipelineModelDropDownData, sd2PipelineModelDropdownData } = - useAppSelector(modelSelector); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], + [pipelineModels?.entities, pipelineModels?.ids, field.value] + ); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && pipelineModels?.ids.includes(field.value)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleValueChanged(firstModel); + }, [field.value, handleValueChanged, pipelineModels?.ids]); return ( - + /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5425d1cfd5..341f0c467b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -101,21 +101,6 @@ const nodesSlice = createSlice({ builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { state.schema = action.payload; }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - state.nodes.forEach((node) => { - forEach(node.data.inputs, (input) => { - if (input.type === 'image') { - if (input.value?.image_name === image_name) { - input.value.image_url = image_url; - input.value.thumbnail_url = thumbnail_url; - } - } - }); - }); - }); }, }); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 5e140b6eef..acad10cf48 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & { export type ImageInputFieldValue = FieldValueBase & { type: 'image'; - value?: ImageDTO; + value?: string; }; export type ModelInputFieldValue = FieldValueBase & { diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index dd5a97e2f1..314af85193 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = ( if (processedControlImage && processorType !== 'none') { // We've already processed the image in the app, so we can just use the processed image - const { image_name } = processedControlImage; controlNetNode.image = { - image_name, + image_name: processedControlImage, }; } else if (controlImage) { // The control image is preprocessed - const { image_name } = controlImage; controlNetNode.image = { - image_name, + image_name: controlImage, }; } else { // Skip ControlNets without an unprocessed image - should never happen if everything is working correctly diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index efaeaddff2..ccdc3e0a27 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -23,6 +23,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 785e1d2fdb..9ffe85b3c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -17,6 +17,7 @@ import { INPAINT_GRAPH, INPAINT, } from './constants'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; + const model = modelIdToPipelineModelField(modelId); + const graph: NonNullableGraph = { id: INPAINT_GRAPH, nodes: { @@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = ( prompt: negativePrompt, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [RANGE_OF_SIZE]: { type: 'range_of_size', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index ca0e56e849..920cb5bf02 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -14,6 +14,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * Builds the Canvas tab's Text to Image graph. @@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 1f2c8327e0..8425ac043a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -22,6 +22,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = ( throw new Error('No initial image found in state'); } + const model = modelIdToPipelineModelField(modelId); + // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { id: IMAGE_TO_IMAGE_GRAPH, @@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', @@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = ( id: RESIZE, type: 'img_resize', image: { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }, is_intermediate: true, width, @@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = ( } else { // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly set(graph.nodes[IMAGE_TO_LATENTS], 'image', { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }); // Pass the image's dimensions to the `NOISE` node diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index c179a89504..973acdfb77 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,6 +1,10 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; +import { + BaseModelType, + RandomIntInvocation, + RangeOfSizeInvocation, +} from 'services/api'; import { ITERATE, LATENTS_TO_IMAGE, @@ -14,6 +18,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; type TextToImageGraphOverrides = { width: number; @@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = ( shouldRandomizeSeed, } = state.generation; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 6a700d4813..072b1a53fd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,9 +1,10 @@ import { Graph } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; +import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; import { AnyInvocation } from 'services/events/types'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * We need to do special handling for some fields @@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'model') { + if (field.value) { + return modelIdToPipelineModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 39e0080d11..7d4469bc41 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,7 +7,7 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const MODEL_LOADER = 'model_loader'; +export const MODEL_LOADER = 'pipeline_model_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts new file mode 100644 index 0000000000..bbcd8d9bc6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts @@ -0,0 +1,18 @@ +import { BaseModelType, PipelineModelField } from 'services/api'; + +/** + * Crudely converts a model id to a pipeline model field + * TODO: Make better + */ +export const modelIdToPipelineModelField = ( + modelId: string +): PipelineModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: PipelineModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts index e29b46af70..6ebd014876 100644 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts @@ -57,7 +57,7 @@ export const buildImg2ImgNode = ( } imageToImageNode.image = { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index fa415074e6..fbfa00e2a1 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const selector = createSelector( [generationSelector], @@ -27,14 +29,21 @@ const InitialImagePreview = () => { const { initialImage } = useAppSelector(selector); const dispatch = useAppDispatch(); + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (droppedImage.image_name === initialImage?.image_name) { + if (droppedImage.image_name === initialImage?.imageName) { return; } dispatch(initialImageChanged(droppedImage)); }, - [dispatch, initialImage?.image_name] + [dispatch, initialImage] ); const handleReset = useCallback(() => { @@ -53,10 +62,10 @@ const InitialImagePreview = () => { }} > } + fallback={} postUploadAction={{ type: 'SET_INITIAL_IMAGE' }} withResetIcon /> diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 488136d9f8..e7dcbf0d83 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,12 +1,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; -import { ModelLoaderTypes } from 'features/system/components/ModelSelect'; import { configChanged } from 'features/system/store/configSlice'; -import { clamp, sortBy } from 'lodash-es'; +import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api'; -import { imageUrlsReceived } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -25,7 +22,7 @@ export interface GenerationState { height: HeightParam; img2imgStrength: StrengthParam; infillMethod: string; - initialImage?: ImageDTO; + initialImage?: { imageName: string; width: number; height: number }; iterations: number; perlin: number; positivePrompt: PositivePromptParam; @@ -50,7 +47,6 @@ export interface GenerationState { horizontalSymmetrySteps: number; verticalSymmetrySteps: number; model: ModelParam; - currentModelType: ModelLoaderTypes; shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; @@ -85,7 +81,6 @@ export const initialGenerationState: GenerationState = { horizontalSymmetrySteps: 0, verticalSymmetrySteps: 0, model: '', - currentModelType: 'sd1_model_loader', shouldUseSeamless: false, seamlessXAxis: true, seamlessYAxis: true, @@ -215,38 +210,20 @@ export const generationSlice = createSlice({ state.shouldUseNoiseSettings = action.payload; }, initialImageChanged: (state, action: PayloadAction) => { - state.initialImage = action.payload; + const { image_name, width, height } = action.payload; + state.initialImage = { imageName: image_name, width, height }; }, modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, - setCurrentModelType: (state, action: PayloadAction) => { - state.currentModelType = action.payload; - }, }, extraReducers: (builder) => { - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (!state.model) { - const firstModel = sortBy(action.payload, 'name')[0]; - state.model = firstModel.name; - } - }); - builder.addCase(configChanged, (state, action) => { const defaultModel = action.payload.sd?.defaultModel; if (defaultModel && !state.model) { state.model = defaultModel; } }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - if (state.initialImage?.image_name === image_name) { - state.initialImage.image_url = image_url; - state.initialImage.thumbnail_url = thumbnail_url; - } - }); }, }); @@ -283,7 +260,6 @@ export const { setVerticalSymmetrySteps, initialImageChanged, modelSelected, - setCurrentModelType, setShouldUseNoiseSettings, setSeamless, setSeamlessXAxis, diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 61567d3fb8..48eb309e7d 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -154,3 +154,17 @@ export type StrengthParam = z.infer; */ export const isValidStrength = (val: unknown): val is StrengthParam => zStrength.safeParse(val).success; + +// /** +// * Zod schema for BaseModelType +// */ +// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); +// /** +// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. +// */ +// export type BaseModelType = z.infer; +// /** +// * Validates/type-guards a value as a base model type +// */ +// export const isValidBaseModelType = (val: unknown): val is BaseModelType => +// zBaseModelType.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 813bd9fb70..43de14d507 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,39 +1,58 @@ -import { memo, useCallback, useEffect } from 'react'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; -import { - modelSelected, - setCurrentModelType, -} from 'features/parameters/store/generationSlice'; +import { modelSelected } from 'features/parameters/store/generationSlice'; -import { modelSelector } from '../store/modelSelectors'; +import { forEach, isString } from 'lodash-es'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; +import { useListModelsQuery } from 'services/apiSlice'; -export type ModelLoaderTypes = 'sd1_model_loader' | 'sd2_model_loader'; - -const MODEL_LOADER_MAP = { - 'sd-1': 'sd1_model_loader', - 'sd-2': 'sd2_model_loader', +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', }; const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { - selectedModel, - sd1PipelineModelDropDownData, - sd2PipelineModelDropdownData, - } = useAppSelector(modelSelector); - useEffect(() => { - if (selectedModel) - dispatch( - setCurrentModelType( - MODEL_LOADER_MAP[selectedModel?.base_model] as ModelLoaderTypes - ) - ); - }, [dispatch, selectedModel]); + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[selectedModelId], + [pipelineModels?.entities, selectedModelId] + ); const handleChangeModel = useCallback( (v: string | null) => { @@ -45,13 +64,27 @@ const ModelSelect = () => { [dispatch] ); + useEffect(() => { + if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleChangeModel(firstModel); + }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + return ( ); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 2e0b3234c7..26c11604e1 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,6 +1,5 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { RootState } from 'app/store/store'; - import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; @@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({ export default function SettingsSchedulers() { const dispatch = useAppDispatch(); + const { t } = useTranslation(); const enabledSchedulers = useAppSelector( diff --git a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts index 193420e29c..8ba5731a5b 100644 --- a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts +++ b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts @@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors'; const isApplicationReadySelector = createSelector( [systemSelector, configSelector], (system, config) => { - const { wereModelsReceived, wasSchemaParsed } = system; + const { wasSchemaParsed } = system; const { disabledTabs } = config; return { disabledTabs, - wereModelsReceived, wasSchemaParsed, }; } @@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector( * Checks if the application is ready to be used, i.e. if the initial startup process is finished. */ export const useIsApplicationReady = () => { - const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( + const { disabledTabs, wasSchemaParsed } = useAppSelector( isApplicationReadySelector ); const isApplicationReady = useMemo(() => { - if (!wereModelsReceived) { - return false; - } - if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { return false; } return true; - }, [disabledTabs, wereModelsReceived, wasSchemaParsed]); + }, [disabledTabs, wasSchemaParsed]); return isApplicationReady; }; diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts deleted file mode 100644 index b63c6d256c..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ /dev/null @@ -1,59 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { IAISelectDataType } from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; -import { isEqual } from 'lodash-es'; - -import { - selectAllSD1PipelineModels, - selectByIdSD1PipelineModels, -} from './models/sd1PipelineModelSlice'; - -import { - selectAllSD2PipelineModels, - selectByIdSD2PipelineModels, -} from './models/sd2PipelineModelSlice'; - -export const modelSelector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - let selectedModel = selectByIdSD1PipelineModels(state, generation.model); - if (selectedModel === undefined) - selectedModel = selectByIdSD2PipelineModels(state, generation.model); - - const sd1PipelineModels = selectAllSD1PipelineModels(state); - const sd2PipelineModels = selectAllSD2PipelineModels(state); - - const allPipelineModels = sd1PipelineModels.concat(sd2PipelineModels); - - const sd1PipelineModelDropDownData = selectAllSD1PipelineModels(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '1.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - const sd2PipelineModelDropdownData = selectAllSD2PipelineModels(state) - .map((m) => ({ - value: m.name, - label: m.name, - group: '2.x Models', - })) - .sort((a, b) => a.label.localeCompare(b.label)); - - return { - selectedModel, - allPipelineModels, - sd1PipelineModels, - sd2PipelineModels, - sd1PipelineModelDropDownData, - sd2PipelineModelDropdownData, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts deleted file mode 100644 index 417a399cf2..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { SD1PipelineModelState } from './models/sd1PipelineModelSlice'; -import { SD2PipelineModelState } from './models/sd2PipelineModelSlice'; - -/** - * Models slice persist denylist - */ -export const modelsPersistDenylist: - | (keyof SD1PipelineModelState)[] - | (keyof SD2PipelineModelState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index c1fdc750de..688f69c1f7 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -20,7 +20,6 @@ import { } from 'services/events/actions'; import { ProgressImage } from 'services/events/types'; import { imageUploaded } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; import { makeToast } from '../../../app/components/Toaster'; import { LANGUAGES } from '../components/LanguagePicker'; @@ -93,6 +92,7 @@ export interface SystemState { shouldAntialiasProgressImage: boolean; language: keyof typeof LANGUAGES; isUploading: boolean; + boardIdToAddTo?: string; } export const initialSystemState: SystemState = { @@ -223,6 +223,7 @@ export const systemSlice = createSlice({ */ builder.addCase(appSocketSubscribed, (state, action) => { state.sessionId = action.payload.sessionId; + state.boardIdToAddTo = action.payload.boardId; state.canceledSession = ''; }); @@ -231,6 +232,7 @@ export const systemSlice = createSlice({ */ builder.addCase(appSocketUnsubscribed, (state) => { state.sessionId = null; + state.boardIdToAddTo = undefined; }); /** @@ -374,13 +376,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Received available models from the backend - */ - builder.addCase(receivedModels.fulfilled, (state) => { - state.wereModelsReceived = true; - }); - /** * OpenAPI schema was parsed */ diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 7481a5daad..8ce42494e5 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI'; export type { AddInvocation } from './models/AddInvocation'; export type { BaseModelType } from './models/BaseModelType'; +export type { BoardChanges } from './models/BoardChanges'; +export type { BoardDTO } from './models/BoardDTO'; +export type { Body_create_board_image } from './models/Body_create_board_image'; +export type { Body_remove_board_image } from './models/Body_remove_board_image'; export type { Body_upload_image } from './models/Body_upload_image'; export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation'; export type { CkptModelInfo } from './models/CkptModelInfo'; @@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField'; export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; export type { ControlField } from './models/ControlField'; export type { ControlNetInvocation } from './models/ControlNetInvocation'; +export type { ControlNetModelConfig } from './models/ControlNetModelConfig'; +export type { ControlNetModelFormat } from './models/ControlNetModelFormat'; export type { ControlOutput } from './models/ControlOutput'; export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; @@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation'; export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntOutput } from './models/IntOutput'; -export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config'; export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { LatentsField } from './models/LatentsField'; @@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { LoraInfo } from './models/LoraInfo'; export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation'; export type { LoraLoaderOutput } from './models/LoraLoaderOutput'; +export type { LoRAModelConfig } from './models/LoRAModelConfig'; +export type { LoRAModelFormat } from './models/LoRAModelFormat'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; @@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseOutput } from './models/NoiseOutput'; export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation'; +export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_'; export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_'; export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; +export type { PipelineModelField } from './models/PipelineModelField'; +export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation'; export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; @@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { SchedulerPredictionType } from './models/SchedulerPredictionType'; -export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation'; -export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation'; +export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig'; +export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig'; +export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat'; +export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig'; +export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig'; +export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat'; export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation'; export type { SubModelType } from './models/SubModelType'; export type { SubtractInvocation } from './models/SubtractInvocation'; export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; +export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig'; export type { UNetField } from './models/UNetField'; export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { VaeField } from './models/VaeField'; +export type { VaeModelConfig } from './models/VaeModelConfig'; +export type { VaeModelFormat } from './models/VaeModelFormat'; export type { VaeRepo } from './models/VaeRepo'; export type { ValidationError } from './models/ValidationError'; export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; +export { BoardsService } from './services/BoardsService'; export { ImagesService } from './services/ImagesService'; export { ModelsService } from './services/ModelsService'; export { SessionsService } from './services/SessionsService'; diff --git a/invokeai/frontend/web/src/services/api/models/BoardChanges.ts b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts new file mode 100644 index 0000000000..fb2bfa0cd9 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type BoardChanges = { + /** + * The board's new name. + */ + board_name?: string; + /** + * The name of the board's new cover image. + */ + cover_image_name?: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/BoardDTO.ts b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts new file mode 100644 index 0000000000..bbcc6f1dd6 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts @@ -0,0 +1,38 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * Deserialized board record with cover image URL and image count. + */ +export type BoardDTO = { + /** + * The unique ID of the board. + */ + board_id: string; + /** + * The name of the board. + */ + board_name: string; + /** + * The created timestamp of the board. + */ + created_at: string; + /** + * The updated timestamp of the board. + */ + updated_at: string; + /** + * The deleted timestamp of the board. + */ + deleted_at?: string; + /** + * The name of the board's cover image. + */ + cover_image_name?: string; + /** + * The number of images in the board. + */ + image_count: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts new file mode 100644 index 0000000000..47f8537eaa --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type Body_create_board_image = { + /** + * The id of the board to add to + */ + board_id: string; + /** + * The name of the image to add + */ + image_name: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts new file mode 100644 index 0000000000..6f5a3652d0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type Body_remove_board_image = { + /** + * The id of the board + */ + board_id: string; + /** + * The name of the image to remove + */ + image_name: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts index e4f77ba7bf..60e2958f5c 100644 --- a/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts @@ -3,16 +3,16 @@ /* eslint-disable */ import type { BaseModelType } from './BaseModelType'; +import type { ControlNetModelFormat } from './ControlNetModelFormat'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; export type ControlNetModelConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'controlnet'; path: string; description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; + model_format: ControlNetModelFormat; error?: ModelError; }; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts new file mode 100644 index 0000000000..500b3e8f8c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type ControlNetModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index e148954f16..5fba3d8311 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from './RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from './ShowImageInvocation'; import type { StepParamEasingInvocation } from './StepParamEasingInvocation'; import type { SubtractInvocation } from './SubtractInvocation'; @@ -73,7 +72,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index e51488aa9a..1e0ea0648f 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -7,7 +7,7 @@ import type { ImageMetadata } from './ImageMetadata'; import type { ResourceOrigin } from './ResourceOrigin'; /** - * Deserialized image record, enriched for the frontend with URLs. + * Deserialized image record, enriched for the frontend. */ export type ImageDTO = { /** @@ -66,4 +66,8 @@ export type ImageDTO = { * A limited subset of the image's generation metadata. Retrieve the image's session for full metadata. */ metadata?: ImageMetadata; + /** + * The id of the board the image belongs to, if one exists. + */ + board_id?: string; }; diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts new file mode 100644 index 0000000000..184a266169 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { LoRAModelFormat } from './LoRAModelFormat'; +import type { ModelError } from './ModelError'; + +export type LoRAModelConfig = { + name: string; + base_model: BaseModelType; + type: 'lora'; + path: string; + description?: string; + model_format: LoRAModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts new file mode 100644 index 0000000000..829f8a7c57 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type LoRAModelFormat = 'lycoris' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/ModelsList.ts b/invokeai/frontend/web/src/services/api/models/ModelsList.ts index a2d88d1967..9186db3e29 100644 --- a/invokeai/frontend/web/src/services/api/models/ModelsList.ts +++ b/invokeai/frontend/web/src/services/api/models/ModelsList.ts @@ -2,16 +2,16 @@ /* tslint:disable */ /* eslint-disable */ -import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config'; +import type { ControlNetModelConfig } from './ControlNetModelConfig'; +import type { LoRAModelConfig } from './LoRAModelConfig'; +import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig'; +import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig'; +import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig'; +import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig'; +import type { TextualInversionModelConfig } from './TextualInversionModelConfig'; +import type { VaeModelConfig } from './VaeModelConfig'; export type ModelsList = { - models: Record>>; + models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>; }; diff --git a/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts new file mode 100644 index 0000000000..2e4734f469 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts @@ -0,0 +1,28 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BoardDTO } from './BoardDTO'; + +/** + * Offset-paginated results + */ +export type OffsetPaginatedResults_BoardDTO_ = { + /** + * Items + */ + items: Array; + /** + * Offset from which to retrieve items + */ + offset: number; + /** + * Limit of items to get + */ + limit: number; + /** + * Total number of items in result + */ + total: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts new file mode 100644 index 0000000000..c2f1c07fbf --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts @@ -0,0 +1,20 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; + +/** + * Pipeline model field + */ +export type PipelineModelField = { + /** + * Name of the model + */ + model_name: string; + /** + * Base model + */ + base_model: BaseModelType; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts similarity index 52% rename from invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts rename to invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts index 9a8a23077a..b8cdb27acf 100644 --- a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts @@ -2,10 +2,12 @@ /* tslint:disable */ /* eslint-disable */ +import type { PipelineModelField } from './PipelineModelField'; + /** - * Loading submodels of selected model. + * Loads a pipeline model, outputting its submodels. */ -export type SD1ModelLoaderInvocation = { +export type PipelineModelLoaderInvocation = { /** * The id of this node. Must be unique among all nodes. */ @@ -14,10 +16,10 @@ export type SD1ModelLoaderInvocation = { * Whether or not this node is an intermediate node. */ is_intermediate?: boolean; - type?: 'sd1_model_loader'; + type?: 'pipeline_model_loader'; /** - * Model to load + * The model to load */ - model_name?: string; + model: PipelineModelField; }; diff --git a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts deleted file mode 100644 index f477c11a8d..0000000000 --- a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -/** - * Loading submodels of selected model. - */ -export type SD2ModelLoaderInvocation = { - /** - * The id of this node. Must be unique among all nodes. - */ - id: string; - /** - * Whether or not this node is an intermediate node. - */ - is_intermediate?: boolean; - type?: 'sd2_model_loader'; - /** - * Model to load - */ - model_name?: string; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts index c9708a0b6f..be7077cc53 100644 --- a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts @@ -4,19 +4,18 @@ import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; import type { ModelVariantType } from './ModelVariantType'; export type StableDiffusion1ModelCheckpointConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; variant: ModelVariantType; }; + diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts index 4b6f834216..befe014605 100644 --- a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts @@ -4,18 +4,17 @@ import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; import type { ModelVariantType } from './ModelVariantType'; export type StableDiffusion1ModelDiffusersConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; }; + diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts new file mode 100644 index 0000000000..01b50c2fc0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts index 27b6879703..dadd7cac9b 100644 --- a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts @@ -4,18 +4,16 @@ import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; export type StableDiffusion2ModelCheckpointConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; @@ -23,3 +21,4 @@ export type StableDiffusion2ModelCheckpointConfig = { prediction_type: SchedulerPredictionType; upcast_attention: boolean; }; + diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts index a2b66d7157..1e4a34c5dc 100644 --- a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts @@ -4,21 +4,20 @@ import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; export type StableDiffusion2ModelDiffusersConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; prediction_type: SchedulerPredictionType; upcast_attention: boolean; }; + diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts new file mode 100644 index 0000000000..7e7b895231 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts index 7abfbec081..97d6aa7ffa 100644 --- a/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts @@ -4,15 +4,14 @@ import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; -import type { ModelType } from './ModelType'; export type TextualInversionModelConfig = { name: string; base_model: BaseModelType; - type: ModelType; + type: 'embedding'; path: string; description?: string; - format: null; - default?: boolean; + model_format: null; error?: ModelError; }; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts new file mode 100644 index 0000000000..a73ee0aa32 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ModelError } from './ModelError'; +import type { VaeModelFormat } from './VaeModelFormat'; + +export type VaeModelConfig = { + name: string; + base_model: BaseModelType; + type: 'vae'; + path: string; + description?: string; + model_format: VaeModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts new file mode 100644 index 0000000000..497f81d16f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type VaeModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts deleted file mode 100644 index f8decdb341..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts deleted file mode 100644 index 614749a2c5..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__lora__LoRAModel__Config = { - path: string; - description?: string; - format: ('lycoris' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts deleted file mode 100644 index 6bdcb87dd4..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts +++ /dev/null @@ -1,18 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; -import type { ModelVariantType } from './ModelVariantType'; - -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = { - path: string; - description?: string; - format: 'checkpoint'; - default?: boolean; - error?: ModelError; - vae?: string; - config?: string; - variant: ModelVariantType; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts deleted file mode 100644 index c88e042178..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts +++ /dev/null @@ -1,17 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; -import type { ModelVariantType } from './ModelVariantType'; - -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = { - path: string; - description?: string; - format: 'diffusers'; - default?: boolean; - error?: ModelError; - vae?: string; - variant: ModelVariantType; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts deleted file mode 100644 index ec2ae4a845..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts +++ /dev/null @@ -1,21 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; -import type { ModelVariantType } from './ModelVariantType'; -import type { SchedulerPredictionType } from './SchedulerPredictionType'; - -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = { - path: string; - description?: string; - format: 'checkpoint'; - default?: boolean; - error?: ModelError; - vae?: string; - config?: string; - variant: ModelVariantType; - prediction_type: SchedulerPredictionType; - upcast_attention: boolean; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts deleted file mode 100644 index 67b897d9d9..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts +++ /dev/null @@ -1,20 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; -import type { ModelVariantType } from './ModelVariantType'; -import type { SchedulerPredictionType } from './SchedulerPredictionType'; - -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = { - path: string; - description?: string; - format: 'diffusers'; - default?: boolean; - error?: ModelError; - vae?: string; - variant: ModelVariantType; - prediction_type: SchedulerPredictionType; - upcast_attention: boolean; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts deleted file mode 100644 index f23d5002e3..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = { - path: string; - description?: string; - format: null; - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts deleted file mode 100644 index d9314a6063..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__vae__VaeModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/services/BoardsService.ts b/invokeai/frontend/web/src/services/api/services/BoardsService.ts new file mode 100644 index 0000000000..236c765cb9 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/services/BoardsService.ts @@ -0,0 +1,247 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { BoardChanges } from '../models/BoardChanges'; +import type { BoardDTO } from '../models/BoardDTO'; +import type { Body_create_board_image } from '../models/Body_create_board_image'; +import type { Body_remove_board_image } from '../models/Body_remove_board_image'; +import type { OffsetPaginatedResults_BoardDTO_ } from '../models/OffsetPaginatedResults_BoardDTO_'; +import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_'; + +import type { CancelablePromise } from '../core/CancelablePromise'; +import { OpenAPI } from '../core/OpenAPI'; +import { request as __request } from '../core/request'; + +export class BoardsService { + + /** + * List Boards + * Gets a list of boards + * @returns any Successful Response + * @throws ApiError + */ + public static listBoards({ + all, + offset, + limit, + }: { + /** + * Whether to list all boards + */ + all?: boolean, + /** + * The page offset + */ + offset?: number, + /** + * The number of boards per page + */ + limit?: number, + }): CancelablePromise<(OffsetPaginatedResults_BoardDTO_ | Array)> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/boards/', + query: { + 'all': all, + 'offset': offset, + 'limit': limit, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Create Board + * Creates a board + * @returns BoardDTO The board was created successfully + * @throws ApiError + */ + public static createBoard({ + boardName, + }: { + /** + * The name of the board to create + */ + boardName: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/boards/', + query: { + 'board_name': boardName, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Get Board + * Gets a board + * @returns BoardDTO Successful Response + * @throws ApiError + */ + public static getBoard({ + boardId, + }: { + /** + * The id of board to get + */ + boardId: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Delete Board + * Deletes a board + * @returns any Successful Response + * @throws ApiError + */ + public static deleteBoard({ + boardId, + }: { + /** + * The id of board to delete + */ + boardId: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'DELETE', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Update Board + * Updates a board + * @returns BoardDTO The board was updated successfully + * @throws ApiError + */ + public static updateBoard({ + boardId, + requestBody, + }: { + /** + * The id of board to update + */ + boardId: string, + requestBody: BoardChanges, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Create Board Image + * Creates a board_image + * @returns any The image was added to a board successfully + * @throws ApiError + */ + public static createBoardImage({ + requestBody, + }: { + requestBody: Body_create_board_image, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/board_images/', + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Remove Board Image + * Deletes a board_image + * @returns any The image was removed from the board successfully + * @throws ApiError + */ + public static removeBoardImage({ + requestBody, + }: { + requestBody: Body_remove_board_image, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'DELETE', + url: '/api/v1/board_images/', + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * List Board Images + * Gets a list of images for a board + * @returns OffsetPaginatedResults_ImageDTO_ Successful Response + * @throws ApiError + */ + public static listBoardImages({ + boardId, + offset, + limit = 10, + }: { + /** + * The id of the board + */ + boardId: string, + /** + * The page offset + */ + offset?: number, + /** + * The number of boards per page + */ + limit?: number, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/board_images/{board_id}', + path: { + 'board_id': boardId, + }, + query: { + 'offset': offset, + 'limit': limit, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + +} diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index 06065eb1a3..bfdef887a0 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -22,33 +22,38 @@ export class ImagesService { * @throws ApiError */ public static listImagesWithMetadata({ -imageOrigin, -categories, -isIntermediate, -offset, -limit = 10, -}: { -/** - * The origin of images to list - */ -imageOrigin?: ResourceOrigin, -/** - * The categories of image to include - */ -categories?: Array, -/** - * Whether to list intermediate images - */ -isIntermediate?: boolean, -/** - * The page offset - */ -offset?: number, -/** - * The number of images per page - */ -limit?: number, -}): CancelablePromise { + imageOrigin, + categories, + isIntermediate, + boardId, + offset, + limit = 10, + }: { + /** + * The origin of images to list + */ + imageOrigin?: ResourceOrigin, + /** + * The categories of image to include + */ + categories?: Array, + /** + * Whether to list intermediate images + */ + isIntermediate?: boolean, + /** + * The board id to filter by + */ + boardId?: string, + /** + * The page offset + */ + offset?: number, + /** + * The number of images per page + */ + limit?: number, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/', @@ -56,6 +61,7 @@ limit?: number, 'image_origin': imageOrigin, 'categories': categories, 'is_intermediate': isIntermediate, + 'board_id': boardId, 'offset': offset, 'limit': limit, }, @@ -72,25 +78,25 @@ limit?: number, * @throws ApiError */ public static uploadImage({ -imageCategory, -isIntermediate, -formData, -sessionId, -}: { -/** - * The category of the image - */ -imageCategory: ImageCategory, -/** - * Whether this is an intermediate image - */ -isIntermediate: boolean, -formData: Body_upload_image, -/** - * The session ID associated with this upload, if any - */ -sessionId?: string, -}): CancelablePromise { + imageCategory, + isIntermediate, + formData, + sessionId, + }: { + /** + * The category of the image + */ + imageCategory: ImageCategory, + /** + * Whether this is an intermediate image + */ + isIntermediate: boolean, + formData: Body_upload_image, + /** + * The session ID associated with this upload, if any + */ + sessionId?: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'POST', url: '/api/v1/images/', @@ -115,13 +121,13 @@ sessionId?: string, * @throws ApiError */ public static getImageFull({ -imageName, -}: { -/** - * The name of full-resolution image file to get - */ -imageName: string, -}): CancelablePromise { + imageName, + }: { + /** + * The name of full-resolution image file to get + */ + imageName: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/{image_name}', @@ -142,13 +148,13 @@ imageName: string, * @throws ApiError */ public static deleteImage({ -imageName, -}: { -/** - * The name of the image to delete - */ -imageName: string, -}): CancelablePromise { + imageName, + }: { + /** + * The name of the image to delete + */ + imageName: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'DELETE', url: '/api/v1/images/{image_name}', @@ -168,15 +174,15 @@ imageName: string, * @throws ApiError */ public static updateImage({ -imageName, -requestBody, -}: { -/** - * The name of the image to update - */ -imageName: string, -requestBody: ImageRecordChanges, -}): CancelablePromise { + imageName, + requestBody, + }: { + /** + * The name of the image to update + */ + imageName: string, + requestBody: ImageRecordChanges, + }): CancelablePromise { return __request(OpenAPI, { method: 'PATCH', url: '/api/v1/images/{image_name}', @@ -198,13 +204,13 @@ requestBody: ImageRecordChanges, * @throws ApiError */ public static getImageMetadata({ -imageName, -}: { -/** - * The name of image to get - */ -imageName: string, -}): CancelablePromise { + imageName, + }: { + /** + * The name of image to get + */ + imageName: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/{image_name}/metadata', @@ -224,13 +230,13 @@ imageName: string, * @throws ApiError */ public static getImageThumbnail({ -imageName, -}: { -/** - * The name of thumbnail image file to get - */ -imageName: string, -}): CancelablePromise { + imageName, + }: { + /** + * The name of thumbnail image file to get + */ + imageName: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/{image_name}/thumbnail', @@ -251,13 +257,13 @@ imageName: string, * @throws ApiError */ public static getImageUrls({ -imageName, -}: { -/** - * The name of the image whose URL to get - */ -imageName: string, -}): CancelablePromise { + imageName, + }: { + /** + * The name of the image whose URL to get + */ + imageName: string, + }): CancelablePromise { return __request(OpenAPI, { method: 'GET', url: '/api/v1/images/{image_name}/urls', diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 2e4a83b25f..51a36caad1 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from '../models/ShowImageInvocation'; import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation'; import type { SubtractInvocation } from '../models/SubtractInvocation'; @@ -175,7 +174,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -212,7 +211,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts new file mode 100644 index 0000000000..e2d765dd90 --- /dev/null +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -0,0 +1,223 @@ +import { + TagDescription, + createApi, + fetchBaseQuery, +} from '@reduxjs/toolkit/query/react'; +import { BoardDTO } from './api/models/BoardDTO'; +import { OffsetPaginatedResults_BoardDTO_ } from './api/models/OffsetPaginatedResults_BoardDTO_'; +import { BoardChanges } from './api/models/BoardChanges'; +import { OffsetPaginatedResults_ImageDTO_ } from './api/models/OffsetPaginatedResults_ImageDTO_'; +import { ImageDTO } from './api/models/ImageDTO'; +import { + FullTagDescription, + TagTypesFrom, + TagTypesFromApi, +} from '@reduxjs/toolkit/dist/query/endpointDefinitions'; +import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; +import { BaseModelType } from './api/models/BaseModelType'; +import { ModelType } from './api/models/ModelType'; +import { ModelsList } from './api/models/ModelsList'; +import { keyBy } from 'lodash-es'; + +type ListBoardsArg = { offset: number; limit: number }; +type UpdateBoardArg = { board_id: string; changes: BoardChanges }; +type AddImageToBoardArg = { board_id: string; image_name: string }; +type RemoveImageFromBoardArg = { board_id: string; image_name: string }; +type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; +type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType }; + +type ModelConfig = ModelsList['models'][number]; + +const tagTypes = ['Board', 'Image', 'Model']; +type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; + +const LIST = 'LIST'; + +const modelsAdapter = createEntityAdapter({ + selectId: (model) => getModelId(model), + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +const getModelId = ({ base_model, type, name }: ModelConfig) => + `${base_model}/${type}/${name}`; + +export const api = createApi({ + baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), + reducerPath: 'api', + tagTypes, + endpoints: (build) => ({ + /** + * Models Queries + */ + + listModels: build.query, ListModelsArg>({ + query: (arg) => ({ url: 'models/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.ids.map((id) => ({ + type: 'Model' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: (response: ModelsList, meta, arg) => { + return modelsAdapter.addMany( + modelsAdapter.getInitialState(), + keyBy(response.models, getModelId) + ); + }, + }), + /** + * Boards Queries + */ + listBoards: build.query({ + query: (arg) => ({ url: 'boards/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.items.map(({ board_id }) => ({ + type: 'Board' as const, + id: board_id, + })) + ); + } + + return tags; + }, + }), + + listAllBoards: build.query, void>({ + query: () => ({ + url: 'boards/', + params: { all: true }, + }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.map(({ board_id }) => ({ + type: 'Board' as const, + id: board_id, + })) + ); + } + + return tags; + }, + }), + + /** + * Boards Mutations + */ + + createBoard: build.mutation({ + query: (board_name) => ({ + url: `boards/`, + method: 'POST', + params: { board_name }, + }), + invalidatesTags: [{ id: 'Board', type: LIST }], + }), + + updateBoard: build.mutation({ + query: ({ board_id, changes }) => ({ + url: `boards/${board_id}`, + method: 'PATCH', + body: changes, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + ], + }), + + deleteBoard: build.mutation({ + query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), + invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }], + }), + + /** + * Board Images Queries + */ + + listBoardImages: build.query< + OffsetPaginatedResults_ImageDTO_, + ListBoardImagesArg + >({ + query: ({ board_id, offset, limit }) => ({ + url: `board_images/${board_id}`, + method: 'DELETE', + body: { offset, limit }, + }), + }), + + /** + * Board Images Mutations + */ + + addImageToBoard: build.mutation({ + query: ({ board_id, image_name }) => ({ + url: `board_images/`, + method: 'POST', + body: { board_id, image_name }, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + { type: 'Image', id: arg.image_name }, + ], + }), + + removeImageFromBoard: build.mutation({ + query: ({ board_id, image_name }) => ({ + url: `board_images/`, + method: 'DELETE', + body: { board_id, image_name }, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + { type: 'Image', id: arg.image_name }, + ], + }), + + /** + * Image Queries + */ + getImageDTO: build.query({ + query: (image_name) => ({ url: `images/${image_name}/metadata` }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }]; + if (result?.board_id) { + tags.push({ type: 'Board', id: result.board_id }); + } + return tags; + }, + }), + }), +}); + +export const { + useListBoardsQuery, + useListAllBoardsQuery, + useCreateBoardMutation, + useUpdateBoardMutation, + useDeleteBoardMutation, + useAddImageToBoardMutation, + useRemoveImageFromBoardMutation, + useListBoardImagesQuery, + useGetImageDTOQuery, + useListModelsQuery, +} = api; diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 5832cb24b1..ed154b9cd8 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -53,14 +53,14 @@ export const appSocketDisconnected = createAction( * Do not use. Only for use in middleware. */ export const socketSubscribed = createAction< - BaseSocketPayload & { sessionId: string } + BaseSocketPayload & { sessionId: string; boardId: string | undefined } >('socket/socketSubscribed'); /** * App-level Socket.IO Subscribed */ export const appSocketSubscribed = createAction< - BaseSocketPayload & { sessionId: string } + BaseSocketPayload & { sessionId: string; boardId: string | undefined } >('socket/appSocketSubscribed'); /** diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index f1eb844f2c..5b427b1690 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -85,6 +85,7 @@ export const socketMiddleware = () => { socketSubscribed({ sessionId: sessionId, timestamp: getTimestamp(), + boardId: getState().boards.selectedBoardId, }) ); } diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 2c4cba510a..62b5864185 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -44,6 +44,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => { socketSubscribed({ sessionId, timestamp: getTimestamp(), + boardId: getState().boards.selectedBoardId, }) ); } diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index a0725bf235..fe198cf6f9 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,5 +1,6 @@ import { createAppAsyncThunk } from 'app/store/storeUtils'; import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { size } from 'lodash-es'; import { ImagesService } from 'services/api'; type imageUrlsReceivedArg = Parameters< @@ -121,25 +122,61 @@ type ImagesListedArg = Parameters< export const IMAGES_PER_PAGE = 20; +const DEFAULT_IMAGES_LISTED_ARG = { + isIntermediate: false, + limit: IMAGES_PER_PAGE, +}; + /** * `ImagesService.listImagesWithMetadata()` thunk */ export const receivedPageOfImages = createAppAsyncThunk( 'api/receivedPageOfImages', - async (_, { getState }) => { + async (arg: ImagesListedArg, { getState }) => { const state = getState(); const { categories } = state.images; + const { selectedBoardId } = state.boards; - const totalImagesInFilter = selectImagesAll(state).filter((i) => - categories.includes(i.image_category) - ).length; - - const response = await ImagesService.listImagesWithMetadata({ - categories, - isIntermediate: false, - offset: totalImagesInFilter, - limit: IMAGES_PER_PAGE, + const images = selectImagesAll(state).filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = selectedBoardId + ? i.board_id === selectedBoardId + : true; + return isInCategory && isInSelectedBoard; }); + + let queryArg: ReceivedImagesArg = {}; + + if (size(arg)) { + queryArg = { + ...DEFAULT_IMAGES_LISTED_ARG, + offset: images.length, + ...arg, + }; + } else { + queryArg = { + ...DEFAULT_IMAGES_LISTED_ARG, + categories, + offset: images.length, + }; + } + + const response = await ImagesService.listImagesWithMetadata(queryArg); + return response; + } +); + +type ReceivedImagesArg = Parameters< + (typeof ImagesService)['listImagesWithMetadata'] +>[0]; + +/** + * `ImagesService.listImagesWithMetadata()` thunk + */ +export const receivedImages = createAppAsyncThunk( + 'api/receivedImages', + async (arg: ReceivedImagesArg, { getState }) => { + const response = await ImagesService.listImagesWithMetadata(arg); return response; } ); diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts deleted file mode 100644 index 619aa4b7b2..0000000000 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ /dev/null @@ -1,58 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { SD1PipelineModel } from 'features/system/store/models/sd1PipelineModelSlice'; -import { SD2PipelineModel } from 'features/system/store/models/sd2PipelineModelSlice'; -import { reduce, size } from 'lodash-es'; -import { BaseModelType, ModelType, ModelsService } from 'services/api'; - -const models = log.child({ namespace: 'model' }); - -export const IMAGES_PER_PAGE = 20; - -type receivedModelsArg = { - baseModel: BaseModelType | undefined; - modelType: ModelType | undefined; -}; - -export const receivedModels = createAppAsyncThunk( - 'models/receivedModels', - async (arg: receivedModelsArg) => { - const response = await ModelsService.listModels(arg); - - let deserializedModels = {}; - - if (arg.baseModel === undefined) return response.models; - if (arg.modelType === undefined) return response.models; - - if (arg.baseModel === 'sd-1') { - deserializedModels = reduce( - response.models[arg.baseModel][arg.modelType], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - return modelsAccumulator; - }, - {} as Record - ); - } - - if (arg.baseModel === 'sd-2') { - deserializedModels = reduce( - response.models[arg.baseModel][arg.modelType], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - return modelsAccumulator; - }, - {} as Record - ); - } - - models.info( - { response }, - `Received ${size(response.models[arg.baseModel][arg.modelType])} ${[ - arg.baseModel, - ]} models` - ); - - return deserializedModels; - } -); diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 334c04e6ed..7ac0d95e6a 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -11,6 +11,7 @@ import { LatentsOutput, ResourceOrigin, ImageDTO, + BoardDTO, } from 'services/api'; export const isImageDTO = (obj: unknown): obj is ImageDTO => { @@ -29,6 +30,16 @@ export const isImageDTO = (obj: unknown): obj is ImageDTO => { ); }; +export const isBoardDTO = (obj: unknown): obj is BoardDTO => { + return ( + isObject(obj) && + 'board_id' in obj && + isString(obj?.board_id) && + 'board_name' in obj && + isString(obj?.board_name) + ); +}; + export const isImageOutput = ( output: GraphExecutionState['results'][string] ): output is ImageOutput => output.type === 'image_output'; diff --git a/tests/test_config.py b/tests/test_config.py index 9317a794c5..cea4991d12 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,8 +8,6 @@ from pathlib import Path os.environ['INVOKEAI_ROOT']='/tmp' from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.invocations.generate import TextToImageInvocation - init1 = OmegaConf.create( ''' @@ -37,13 +35,13 @@ def test_use_init(): # sys.argv respectively. conf1 = InvokeAIAppConfig.get_config() assert conf1 - conf1.parse_args(conf=init1) + conf1.parse_args(conf=init1,argv=[]) assert conf1.max_loaded_models==5 assert not conf1.nsfw_checker conf2 = InvokeAIAppConfig.get_config() assert conf2 - conf2.parse_args(conf=init2) + conf2.parse_args(conf=init2,argv=[]) assert conf2.nsfw_checker assert conf2.max_loaded_models==2 assert not hasattr(conf2,'invalid_attribute') @@ -67,7 +65,7 @@ def test_env_override(): # environment variables should be case insensitive os.environ['InvokeAI_Max_Loaded_Models'] = '15' conf = InvokeAIAppConfig() - conf.parse_args(conf=init1) + conf.parse_args(conf=init1,argv=[]) assert conf.max_loaded_models == 15 conf = InvokeAIAppConfig()