diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 5599d569d5..efeb778922 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -2,8 +2,17 @@ from logging import Logger import os +from invokeai.app.services.board_image_record_storage import ( + SqliteBoardImageRecordStorage, +) +from invokeai.app.services.board_images import ( + BoardImagesService, + BoardImagesServiceDependencies, +) +from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage +from invokeai.app.services.boards import BoardService, BoardServiceDependencies from invokeai.app.services.image_record_storage import SqliteImageRecordStorage -from invokeai.app.services.images import ImageService +from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService @@ -57,7 +66,7 @@ class ApiDependencies: # TODO: build a file/path manager? db_location = config.db_path - db_location.parent.mkdir(parents=True,exist_ok=True) + db_location.parent.mkdir(parents=True, exist_ok=True) graph_execution_manager = SqliteItemStorage[GraphExecutionState]( filename=db_location, table_name="graph_executions" @@ -72,14 +81,40 @@ class ApiDependencies: DiskLatentsStorage(f"{output_folder}/latents") ) + board_record_storage = SqliteBoardRecordStorage(db_location) + board_image_record_storage = SqliteBoardImageRecordStorage(db_location) + + boards = BoardService( + services=BoardServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) + + board_images = BoardImagesService( + services=BoardImagesServiceDependencies( + board_image_record_storage=board_image_record_storage, + board_record_storage=board_record_storage, + image_record_storage=image_record_storage, + url=urls, + logger=logger, + ) + ) + images = ImageService( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=urls, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, + services=ImageServiceDependencies( + board_image_record_storage=board_image_record_storage, + image_record_storage=image_record_storage, + image_file_storage=image_file_storage, + metadata=metadata, + url=urls, + logger=logger, + names=names, + graph_execution_manager=graph_execution_manager, + ) ) services = InvocationServices( @@ -87,6 +122,8 @@ class ApiDependencies: events=events, latents=latents, images=images, + boards=boards, + board_images=board_images, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py new file mode 100644 index 0000000000..b206ab500d --- /dev/null +++ b/invokeai/app/api/routers/board_images.py @@ -0,0 +1,69 @@ +from fastapi import Body, HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.models.image_record import ImageDTO + +from ..dependencies import ApiDependencies + +board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) + + +@board_images_router.post( + "/", + operation_id="create_board_image", + responses={ + 201: {"description": "The image was added to a board successfully"}, + }, + status_code=201, +) +async def create_board_image( + board_id: str = Body(description="The id of the board to add to"), + image_name: str = Body(description="The name of the image to add"), +): + """Creates a board_image""" + try: + result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to add to board") + +@board_images_router.delete( + "/", + operation_id="remove_board_image", + responses={ + 201: {"description": "The image was removed from the board successfully"}, + }, + status_code=201, +) +async def remove_board_image( + board_id: str = Body(description="The id of the board"), + image_name: str = Body(description="The name of the image to remove"), +): + """Deletes a board_image""" + try: + result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to update board") + + + +@board_images_router.get( + "/{board_id}", + operation_id="list_board_images", + response_model=OffsetPaginatedResults[ImageDTO], +) +async def list_board_images( + board_id: str = Path(description="The id of the board"), + offset: int = Query(default=0, description="The page offset"), + limit: int = Query(default=10, description="The number of boards per page"), +) -> OffsetPaginatedResults[ImageDTO]: + """Gets a list of images for a board""" + + results = ApiDependencies.invoker.services.board_images.get_images_for_board( + board_id, + ) + return results + diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py new file mode 100644 index 0000000000..55cd7c8ca2 --- /dev/null +++ b/invokeai/app/api/routers/boards.py @@ -0,0 +1,108 @@ +from typing import Optional, Union +from fastapi import Body, HTTPException, Path, Query +from fastapi.routing import APIRouter +from invokeai.app.services.board_record_storage import BoardChanges +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import BoardDTO + +from ..dependencies import ApiDependencies + +boards_router = APIRouter(prefix="/v1/boards", tags=["boards"]) + + +@boards_router.post( + "/", + operation_id="create_board", + responses={ + 201: {"description": "The board was created successfully"}, + }, + status_code=201, + response_model=BoardDTO, +) +async def create_board( + board_name: str = Query(description="The name of the board to create"), +) -> BoardDTO: + """Creates a board""" + try: + result = ApiDependencies.invoker.services.boards.create(board_name=board_name) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to create board") + + +@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO) +async def get_board( + board_id: str = Path(description="The id of board to get"), +) -> BoardDTO: + """Gets a board""" + + try: + result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + return result + except Exception as e: + raise HTTPException(status_code=404, detail="Board not found") + + +@boards_router.patch( + "/{board_id}", + operation_id="update_board", + responses={ + 201: { + "description": "The board was updated successfully", + }, + }, + status_code=201, + response_model=BoardDTO, +) +async def update_board( + board_id: str = Path(description="The id of board to update"), + changes: BoardChanges = Body(description="The changes to apply to the board"), +) -> BoardDTO: + """Updates a board""" + try: + result = ApiDependencies.invoker.services.boards.update( + board_id=board_id, changes=changes + ) + return result + except Exception as e: + raise HTTPException(status_code=500, detail="Failed to update board") + + +@boards_router.delete("/{board_id}", operation_id="delete_board") +async def delete_board( + board_id: str = Path(description="The id of board to delete"), +) -> None: + """Deletes a board""" + + try: + ApiDependencies.invoker.services.boards.delete(board_id=board_id) + except Exception as e: + # TODO: Does this need any exception handling at all? + pass + + +@boards_router.get( + "/", + operation_id="list_boards", + response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]], +) +async def list_boards( + all: Optional[bool] = Query(default=None, description="Whether to list all boards"), + offset: Optional[int] = Query(default=None, description="The page offset"), + limit: Optional[int] = Query( + default=None, description="The number of boards per page" + ), +) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]: + """Gets a list of boards""" + if all: + return ApiDependencies.invoker.services.boards.get_all() + elif offset is not None and limit is not None: + return ApiDependencies.invoker.services.boards.get_many( + offset, + limit, + ) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'", + ) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 11453d97f1..a8c84b81b9 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -221,6 +221,9 @@ async def list_images_with_metadata( is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), + board_id: Optional[str] = Query( + default=None, description="The board id to filter by" + ), offset: int = Query(default=0, description="The page offset"), limit: int = Query(default=10, description="The number of images per page"), ) -> OffsetPaginatedResults[ImageDTO]: @@ -232,6 +235,7 @@ async def list_images_with_metadata( image_origin, categories, is_intermediate, + board_id, ) return image_dtos diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index f510279f18..50d645eb57 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management.models import get_all_model_configs -MODEL_CONFIGS = Union[tuple(get_all_model_configs())] +from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS +MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel): info: DiffusersModelInfo = Field(description="The converted model info") class ModelsList(BaseModel): - models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend - #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] + models: list[MODEL_CONFIGS] @models_router.get( @@ -72,10 +71,10 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: BaseModelType = Query( + base_model: Optional[BaseModelType] = Query( default=None, description="Base model" ), - model_type: ModelType = Query( + model_type: Optional[ModelType] = Query( default=None, description="The type of model to get" ), ) -> ModelsList: diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index fa46762d56..e14c58bab7 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config) import invokeai.frontend.web as web_dir from .api.dependencies import ApiDependencies -from .api.routers import sessions, models, images +from .api.routers import sessions, models, images, boards, board_images from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation @@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api") app.include_router(images.images_router, prefix="/api") +app.include_router(boards.boards_router, prefix="/api") + +app.include_router(board_images.board_images_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? def custom_openapi(): @@ -116,6 +120,22 @@ def custom_openapi(): invoker_schema["output"] = outputs_ref + from invokeai.backend.model_management.models import get_model_config_enums + for model_config_format_enum in set(get_model_config_enums()): + name = model_config_format_enum.__qualname__ + + if name in openapi_schema["components"]["schemas"]: + # print(f"Config with name {name} already defined") + continue + + # "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"} + openapi_schema["components"]["schemas"][name] = dict( + title=name, + description="An enumeration.", + type="string", + enum=list(v.value for v in model_config_format_enum), + ) + app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 9d77cadf8c..b77aa5dafd 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput): #fmt: on -class SD1ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" +class PipelineModelField(BaseModel): + """Pipeline model field""" - type: Literal["sd1_model_loader"] = "sd1_model_loader" + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") - model_name: str = Field(default="", description="Model to load") + +class PipelineModelLoaderInvocation(BaseInvocation): + """Loads a pipeline model, outputting its submodels.""" + + type: Literal["pipeline_model_loader"] = "pipeline_model_loader" + + model: PipelineModelField = Field(description="The model to load") # TODO: precision? # Schema customisation @@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation): "ui": { "tags": ["model", "loader"], "type_hints": { - "model_name": "model" # TODO: rename to model_name? + "model": "model" } }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = BaseModelType.StableDiffusion1 # TODO: + base_model = self.model.base_model + model_name = self.model.model_name + model_type = ModelType.Pipeline # TODO: not found exceptions if not context.services.model_manager.model_exists( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, ): - raise Exception(f"Unkown model name: {self.model_name}!") + raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") """ if not context.services.model_manager.model_exists( @@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation): return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.UNet, ), scheduler=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.TextEncoder, ), loras=[], ), vae=VaeField( vae=ModelInfo( - model_name=self.model_name, + model_name=model_name, base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Vae, - ), - ) - ) - -# TODO: optimize(less code copy) -class SD2ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" - - type: Literal["sd2_model_loader"] = "sd2_model_loader" - - model_name: str = Field(default="", description="Model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["model", "loader"], - "type_hints": { - "model_name": "model" # TODO: rename to model_name? - } - }, - } - - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - - base_model = BaseModelType.StableDiffusion2 # TODO: - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - ): - raise Exception(f"Unkown model name: {self.model_name}!") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ - - - return ModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, - submodel=SubModelType.TextEncoder, - ), - loras=[], - ), - vae=VaeField( - vae=ModelInfo( - model_name=self.model_name, - base_model=base_model, - model_type=ModelType.Pipeline, + model_type=model_type, submodel=SubModelType.Vae, ), ) diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py new file mode 100644 index 0000000000..7aff41860c --- /dev/null +++ b/invokeai/app/services/board_image_record_storage.py @@ -0,0 +1,254 @@ +from abc import ABC, abstractmethod +import sqlite3 +import threading +from typing import Union, cast +from invokeai.app.services.board_record_storage import BoardRecord + +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.image_record import ( + ImageRecord, + deserialize_image_record, +) + + +class BoardImageRecordStorageBase(ABC): + """Abstract base class for the one-to-many board-image relationship record storage.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageRecord]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" + pass + + @abstractmethod + def get_image_count_for_board( + self, + board_id: str, + ) -> int: + """Gets the number of images for a board.""" + pass + + +class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `board_images` junction table.""" + + # Create the `board_images` junction table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS board_images ( + board_id TEXT NOT NULL, + image_name TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + -- enforce one-to-many relationship between boards and images using PK + -- (we can extend this to many-to-many later) + PRIMARY KEY (image_name), + FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE, + FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE + ); + """ + ) + + # Add index for board id + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id); + """ + ) + + # Add index for board id, sorted by created_at + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at + AFTER UPDATE + ON board_images FOR EACH ROW + BEGIN + UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE board_id = old.board_id AND image_name = old.image_name; + END; + """ + ) + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT INTO board_images (board_id, image_name) + VALUES (?, ?) + ON CONFLICT (image_name) DO UPDATE SET board_id = ?; + """, + (board_id, image_name, board_id), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM board_images + WHERE board_id = ? AND image_name = ?; + """, + (board_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_images_for_board( + self, + board_id: str, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[ImageRecord]: + # TODO: this isn't paginated yet? + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT images.* + FROM board_images + INNER JOIN images ON board_images.image_name = images.image_name + WHERE board_images.board_id = ? + ORDER BY board_images.updated_at DESC; + """, + (board_id,), + ) + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM images WHERE 1=1; + """ + ) + count = cast(int, self._cursor.fetchone()[0]) + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count + ) + + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT board_id + FROM board_images + WHERE image_name = ?; + """, + (image_name,), + ) + result = self._cursor.fetchone() + if result is None: + return None + return cast(str, result[0]) + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_image_count_for_board(self, board_id: str) -> int: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT COUNT(*) FROM board_images WHERE board_id = ?; + """, + (board_id,), + ) + count = cast(int, self._cursor.fetchone()[0]) + return count + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py new file mode 100644 index 0000000000..072effbfae --- /dev/null +++ b/invokeai/app/services/board_images.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +from logging import Logger +from typing import List, Union +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_record_storage import ( + BoardRecord, + BoardRecordStorageBase, +) + +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto +from invokeai.app.services.urls import UrlServiceBase + + +class BoardImagesServiceABC(ABC): + """High-level service for board-image relationship management.""" + + @abstractmethod + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Adds an image to a board.""" + pass + + @abstractmethod + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + """Removes an image from a board.""" + pass + + @abstractmethod + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + """Gets images for a board.""" + pass + + @abstractmethod + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + """Gets an image's board id, if it has one.""" + pass + + +class BoardImagesServiceDependencies: + """Service dependencies for the BoardImagesService.""" + + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger + + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger + + +class BoardImagesService(BoardImagesServiceABC): + _services: BoardImagesServiceDependencies + + def __init__(self, services: BoardImagesServiceDependencies): + self._services = services + + def add_image_to_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.add_image_to_board(board_id, image_name) + + def remove_image_from_board( + self, + board_id: str, + image_name: str, + ) -> None: + self._services.board_image_records.remove_image_from_board(board_id, image_name) + + def get_images_for_board( + self, + board_id: str, + ) -> OffsetPaginatedResults[ImageDTO]: + image_records = self._services.board_image_records.get_images_for_board( + board_id + ) + image_dtos = list( + map( + lambda r: image_record_to_dto( + r, + self._services.urls.get_image_url(r.image_name), + self._services.urls.get_image_url(r.image_name, True), + board_id, + ), + image_records.items, + ) + ) + return OffsetPaginatedResults[ImageDTO]( + items=image_dtos, + offset=image_records.offset, + limit=image_records.limit, + total=image_records.total, + ) + + def get_board_for_image( + self, + image_name: str, + ) -> Union[str, None]: + board_id = self._services.board_image_records.get_board_for_image(image_name) + return board_id + + +def board_record_to_dto( + board_record: BoardRecord, cover_image_name: str | None, image_count: int +) -> BoardDTO: + """Converts a board record to a board DTO.""" + return BoardDTO( + **board_record.dict(exclude={'cover_image_name'}), + cover_image_name=cover_image_name, + image_count=image_count, + ) diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py new file mode 100644 index 0000000000..15ea9cc5a7 --- /dev/null +++ b/invokeai/app/services/board_record_storage.py @@ -0,0 +1,329 @@ +from abc import ABC, abstractmethod +from typing import Optional, cast +import sqlite3 +import threading +from typing import Optional, Union +import uuid +from invokeai.app.services.image_record_storage import OffsetPaginatedResults +from invokeai.app.services.models.board_record import ( + BoardRecord, + deserialize_board_record, +) + +from pydantic import BaseModel, Field, Extra + + +class BoardChanges(BaseModel, extra=Extra.forbid): + board_name: Optional[str] = Field(description="The board's new name.") + cover_image_name: Optional[str] = Field( + description="The name of the board's new cover image." + ) + + +class BoardRecordNotFoundException(Exception): + """Raised when an board record is not found.""" + + def __init__(self, message="Board record not found"): + super().__init__(message) + + +class BoardRecordSaveException(Exception): + """Raised when an board record cannot be saved.""" + + def __init__(self, message="Board record not saved"): + super().__init__(message) + + +class BoardRecordDeleteException(Exception): + """Raised when an board record cannot be deleted.""" + + def __init__(self, message="Board record not deleted"): + super().__init__(message) + + +class BoardRecordStorageBase(ABC): + """Low-level service responsible for interfacing with the board record store.""" + + @abstractmethod + def delete(self, board_id: str) -> None: + """Deletes a board record.""" + pass + + @abstractmethod + def save( + self, + board_name: str, + ) -> BoardRecord: + """Saves a board record.""" + pass + + @abstractmethod + def get( + self, + board_id: str, + ) -> BoardRecord: + """Gets a board record.""" + pass + + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardRecord: + """Updates a board record.""" + pass + + @abstractmethod + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + """Gets many board records.""" + pass + + @abstractmethod + def get_all( + self, + ) -> list[BoardRecord]: + """Gets all board records.""" + pass + + +class SqliteBoardRecordStorage(BoardRecordStorageBase): + _filename: str + _conn: sqlite3.Connection + _cursor: sqlite3.Cursor + _lock: threading.Lock + + def __init__(self, filename: str) -> None: + super().__init__() + self._filename = filename + self._conn = sqlite3.connect(filename, check_same_thread=False) + # Enable row factory to get rows as dictionaries (must be done before making the cursor!) + self._conn.row_factory = sqlite3.Row + self._cursor = self._conn.cursor() + self._lock = threading.Lock() + + try: + self._lock.acquire() + # Enable foreign keys + self._conn.execute("PRAGMA foreign_keys = ON;") + self._create_tables() + self._conn.commit() + finally: + self._lock.release() + + def _create_tables(self) -> None: + """Creates the `boards` table and `board_images` junction table.""" + + # Create the `boards` table. + self._cursor.execute( + """--sql + CREATE TABLE IF NOT EXISTS boards ( + board_id TEXT NOT NULL PRIMARY KEY, + board_name TEXT NOT NULL, + cover_image_name TEXT, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Soft delete, currently unused + deleted_at DATETIME, + FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL + ); + """ + ) + + self._cursor.execute( + """--sql + CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at); + """ + ) + + # Add trigger for `updated_at`. + self._cursor.execute( + """--sql + CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at + AFTER UPDATE + ON boards FOR EACH ROW + BEGIN + UPDATE boards SET updated_at = current_timestamp + WHERE board_id = old.board_id; + END; + """ + ) + + def delete(self, board_id: str) -> None: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + DELETE FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordDeleteException from e + except Exception as e: + self._conn.rollback() + raise BoardRecordDeleteException from e + finally: + self._lock.release() + + def save( + self, + board_name: str, + ) -> BoardRecord: + try: + board_id = str(uuid.uuid4()) + self._lock.acquire() + self._cursor.execute( + """--sql + INSERT OR IGNORE INTO boards (board_id, board_name) + VALUES (?, ?); + """, + (board_id, board_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + return self.get(board_id) + + def get( + self, + board_id: str, + ) -> BoardRecord: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT * + FROM boards + WHERE board_id = ?; + """, + (board_id,), + ) + + result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordNotFoundException from e + finally: + self._lock.release() + if result is None: + raise BoardRecordNotFoundException + return BoardRecord(**dict(result)) + + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardRecord: + try: + self._lock.acquire() + + # Change the name of a board + if changes.board_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET board_name = ? + WHERE board_id = ?; + """, + (changes.board_name, board_id), + ) + + # Change the cover image of a board + if changes.cover_image_name is not None: + self._cursor.execute( + f"""--sql + UPDATE boards + SET cover_image_name = ? + WHERE board_id = ?; + """, + (changes.cover_image_name, board_id), + ) + + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise BoardRecordSaveException from e + finally: + self._lock.release() + return self.get(board_id) + + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardRecord]: + try: + self._lock.acquire() + + # Get all the boards + self._cursor.execute( + """--sql + SELECT * + FROM boards + ORDER BY created_at DESC + LIMIT ? OFFSET ?; + """, + (limit, offset), + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + + # Get the total number of boards + self._cursor.execute( + """--sql + SELECT COUNT(*) + FROM boards + WHERE 1=1; + """ + ) + + count = cast(int, self._cursor.fetchone()[0]) + + return OffsetPaginatedResults[BoardRecord]( + items=boards, offset=offset, limit=limit, total=count + ) + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() + + def get_all( + self, + ) -> list[BoardRecord]: + try: + self._lock.acquire() + + # Get all the boards + self._cursor.execute( + """--sql + SELECT * + FROM boards + ORDER BY created_at DESC + """ + ) + + result = cast(list[sqlite3.Row], self._cursor.fetchall()) + boards = list(map(lambda r: deserialize_board_record(dict(r)), result)) + + return boards + + except sqlite3.Error as e: + self._conn.rollback() + raise e + finally: + self._lock.release() diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py new file mode 100644 index 0000000000..9361322e6c --- /dev/null +++ b/invokeai/app/services/boards.py @@ -0,0 +1,185 @@ +from abc import ABC, abstractmethod + +from logging import Logger +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase +from invokeai.app.services.board_images import board_record_to_dto + +from invokeai.app.services.board_record_storage import ( + BoardChanges, + BoardRecordStorageBase, +) +from invokeai.app.services.image_record_storage import ( + ImageRecordStorageBase, + OffsetPaginatedResults, +) +from invokeai.app.services.models.board_record import BoardDTO +from invokeai.app.services.urls import UrlServiceBase + + +class BoardServiceABC(ABC): + """High-level service for board management.""" + + @abstractmethod + def create( + self, + board_name: str, + ) -> BoardDTO: + """Creates a board.""" + pass + + @abstractmethod + def get_dto( + self, + board_id: str, + ) -> BoardDTO: + """Gets a board.""" + pass + + @abstractmethod + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + """Updates a board.""" + pass + + @abstractmethod + def delete( + self, + board_id: str, + ) -> None: + """Deletes a board.""" + pass + + @abstractmethod + def get_many( + self, + offset: int = 0, + limit: int = 10, + ) -> OffsetPaginatedResults[BoardDTO]: + """Gets many boards.""" + pass + + @abstractmethod + def get_all( + self, + ) -> list[BoardDTO]: + """Gets all boards.""" + pass + + +class BoardServiceDependencies: + """Service dependencies for the BoardService.""" + + board_image_records: BoardImageRecordStorageBase + board_records: BoardRecordStorageBase + image_records: ImageRecordStorageBase + urls: UrlServiceBase + logger: Logger + + def __init__( + self, + board_image_record_storage: BoardImageRecordStorageBase, + image_record_storage: ImageRecordStorageBase, + board_record_storage: BoardRecordStorageBase, + url: UrlServiceBase, + logger: Logger, + ): + self.board_image_records = board_image_record_storage + self.image_records = image_record_storage + self.board_records = board_record_storage + self.urls = url + self.logger = logger + + +class BoardService(BoardServiceABC): + _services: BoardServiceDependencies + + def __init__(self, services: BoardServiceDependencies): + self._services = services + + def create( + self, + board_name: str, + ) -> BoardDTO: + board_record = self._services.board_records.save(board_name) + return board_record_to_dto(board_record, None, 0) + + def get_dto(self, board_id: str) -> BoardDTO: + board_record = self._services.board_records.get(board_id) + cover_image = self._services.image_records.get_most_recent_image_for_board( + board_record.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + image_count = self._services.board_image_records.get_image_count_for_board( + board_id + ) + return board_record_to_dto(board_record, cover_image_name, image_count) + + def update( + self, + board_id: str, + changes: BoardChanges, + ) -> BoardDTO: + board_record = self._services.board_records.update(board_id, changes) + cover_image = self._services.image_records.get_most_recent_image_for_board( + board_record.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + board_id + ) + return board_record_to_dto(board_record, cover_image_name, image_count) + + def delete(self, board_id: str) -> None: + self._services.board_records.delete(board_id) + + def get_many( + self, offset: int = 0, limit: int = 10 + ) -> OffsetPaginatedResults[BoardDTO]: + board_records = self._services.board_records.get_many(offset, limit) + board_dtos = [] + for r in board_records.items: + cover_image = self._services.image_records.get_most_recent_image_for_board( + r.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) + + return OffsetPaginatedResults[BoardDTO]( + items=board_dtos, offset=offset, limit=limit, total=len(board_dtos) + ) + + def get_all(self) -> list[BoardDTO]: + board_records = self._services.board_records.get_all() + board_dtos = [] + for r in board_records: + cover_image = self._services.image_records.get_most_recent_image_for_board( + r.board_id + ) + if cover_image: + cover_image_name = cover_image.image_name + else: + cover_image_name = None + + image_count = self._services.board_image_records.get_image_count_for_board( + r.board_id + ) + board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) + + return board_dtos \ No newline at end of file diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 30b379ed8b..c34d2ca5c8 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC): """Saves an image record.""" pass + @abstractmethod + def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None: + """Gets the most recent image for a board.""" + pass + class SqliteImageRecordStorage(ImageRecordStorageBase): _filename: str @@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self._lock.release() def _create_tables(self) -> None: - """Creates the tables for the `images` database.""" + """Creates the `images` table.""" # Create the `images` table. self._cursor.execute( @@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id TEXT, metadata TEXT, is_intermediate BOOLEAN DEFAULT FALSE, + board_id TEXT, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), @@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): AFTER UPDATE ON images FOR EACH ROW BEGIN - UPDATE images SET updated_at = current_timestamp + UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') WHERE image_name = old.image_name; END; """ @@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """, (changes.is_intermediate, image_name), ) + self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() # Manually build two queries - one for the count, one for the records + count_query = """--sql + SELECT COUNT(*) + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ - count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" - images_query = f"""SELECT * FROM images WHERE 1=1\n""" + images_query = """--sql + SELECT images.* + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ query_conditions = "" query_params = [] if image_origin is not None: - query_conditions += f"""AND image_origin = ?\n""" + query_conditions += """--sql + AND images.image_origin = ? + """ query_params.append(image_origin.value) if categories is not None: - ## Convert the enum values to unique list of strings + # Convert the enum values to unique list of strings category_strings = list(map(lambda c: c.value, set(categories))) # Create the correct length of placeholders placeholders = ",".join("?" * len(category_strings)) - query_conditions += f"AND image_category IN ( {placeholders} )\n" + + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ # Unpack the included categories into the query params for c in category_strings: query_params.append(c) if is_intermediate is not None: - query_conditions += f"""AND is_intermediate = ?\n""" + query_conditions += """--sql + AND images.is_intermediate = ? + """ + query_params.append(is_intermediate) - query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" + if board_id is not None: + query_conditions += """--sql + AND board_images.board_id = ? + """ + + query_params.append(board_id) + + query_pagination = """--sql + ORDER BY images.created_at DESC LIMIT ? OFFSET ? + """ # Final images query with pagination images_query += query_conditions + query_pagination + ";" @@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): count_query += query_conditions + ";" count_params = query_params.copy() self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] + count = cast(int, self._cursor.fetchone()[0]) except sqlite3.Error as e: self._conn.rollback() raise e @@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): raise ImageRecordSaveException from e finally: self._lock.release() + + def get_most_recent_image_for_board( + self, board_id: str + ) -> Union[ImageRecord, None]: + try: + self._lock.acquire() + self._cursor.execute( + """--sql + SELECT images.* + FROM images + JOIN board_images ON images.image_name = board_images.image_name + WHERE board_images.board_id = ? + ORDER BY images.created_at DESC + LIMIT 1; + """, + (board_id,), + ) + + result = cast(Union[sqlite3.Row, None], self._cursor.fetchone()) + finally: + self._lock.release() + if result is None: + return None + + return deserialize_image_record(dict(result)) diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 9f7188f607..542f874f1d 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -10,6 +10,7 @@ from invokeai.app.models.image import ( InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata +from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase from invokeai.app.services.image_record_storage import ( ImageRecordDeleteException, ImageRecordNotFoundException, @@ -49,7 +50,7 @@ class ImageServiceABC(ABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - intermediate: bool = False, + is_intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -79,7 +80,7 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_path(self, image_name: str) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: """Gets an image's path.""" pass @@ -101,6 +102,7 @@ class ImageServiceABC(ABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @@ -114,8 +116,9 @@ class ImageServiceABC(ABC): class ImageServiceDependencies: """Service dependencies for the ImageService.""" - records: ImageRecordStorageBase - files: ImageFileStorageBase + image_records: ImageRecordStorageBase + image_files: ImageFileStorageBase + board_image_records: BoardImageRecordStorageBase metadata: MetadataServiceBase urls: UrlServiceBase logger: Logger @@ -126,14 +129,16 @@ class ImageServiceDependencies: self, image_record_storage: ImageRecordStorageBase, image_file_storage: ImageFileStorageBase, + board_image_record_storage: BoardImageRecordStorageBase, metadata: MetadataServiceBase, url: UrlServiceBase, logger: Logger, names: NameServiceBase, graph_execution_manager: ItemStorageABC["GraphExecutionState"], ): - self.records = image_record_storage - self.files = image_file_storage + self.image_records = image_record_storage + self.image_files = image_file_storage + self.board_image_records = board_image_record_storage self.metadata = metadata self.urls = url self.logger = logger @@ -144,25 +149,8 @@ class ImageServiceDependencies: class ImageService(ImageServiceABC): _services: ImageServiceDependencies - def __init__( - self, - image_record_storage: ImageRecordStorageBase, - image_file_storage: ImageFileStorageBase, - metadata: MetadataServiceBase, - url: UrlServiceBase, - logger: Logger, - names: NameServiceBase, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - ): - self._services = ImageServiceDependencies( - image_record_storage=image_record_storage, - image_file_storage=image_file_storage, - metadata=metadata, - url=url, - logger=logger, - names=names, - graph_execution_manager=graph_execution_manager, - ) + def __init__(self, services: ImageServiceDependencies): + self._services = services def create( self, @@ -187,7 +175,7 @@ class ImageService(ImageServiceABC): try: # TODO: Consider using a transaction here to ensure consistency between storage and database - created_at = self._services.records.save( + self._services.image_records.save( # Non-nullable fields image_name=image_name, image_origin=image_origin, @@ -202,35 +190,15 @@ class ImageService(ImageServiceABC): metadata=metadata, ) - self._services.files.save( + self._services.image_files.save( image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_name) - thumbnail_url = self._services.urls.get_image_url(image_name, True) + image_dto = self.get_dto(image_name) - return ImageDTO( - # Non-nullable fields - image_name=image_name, - image_origin=image_origin, - image_category=image_category, - width=width, - height=height, - # Nullable fields - node_id=node_id, - session_id=session_id, - metadata=metadata, - # Meta fields - created_at=created_at, - updated_at=created_at, # this is always the same as the created_at at this time - deleted_at=None, - is_intermediate=is_intermediate, - # Extra non-nullable fields for DTO - image_url=image_url, - thumbnail_url=thumbnail_url, - ) + return image_dto except ImageRecordSaveException: self._services.logger.error("Failed to save image record") raise @@ -247,7 +215,7 @@ class ImageService(ImageServiceABC): changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, changes) + self._services.image_records.update(image_name, changes) return self.get_dto(image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") @@ -258,7 +226,7 @@ class ImageService(ImageServiceABC): def get_pil_image(self, image_name: str) -> PILImageType: try: - return self._services.files.get(image_name) + return self._services.image_files.get(image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -268,7 +236,7 @@ class ImageService(ImageServiceABC): def get_record(self, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_name) + return self._services.image_records.get(image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -278,12 +246,13 @@ class ImageService(ImageServiceABC): def get_dto(self, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_name) + image_record = self._services.image_records.get(image_name) image_dto = image_record_to_dto( image_record, self._services.urls.get_image_url(image_name), self._services.urls.get_image_url(image_name, True), + self._services.board_image_records.get_board_for_image(image_name), ) return image_dto @@ -296,14 +265,14 @@ class ImageService(ImageServiceABC): def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.files.get_path(image_name, thumbnail) + return self._services.image_files.get_path(image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e def validate_path(self, path: str) -> bool: try: - return self._services.files.validate_path(path) + return self._services.image_files.validate_path(path) except Exception as e: self._services.logger.error("Problem validating image path") raise e @@ -322,14 +291,16 @@ class ImageService(ImageServiceABC): image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: try: - results = self._services.records.get_many( + results = self._services.image_records.get_many( offset, limit, image_origin, categories, is_intermediate, + board_id, ) image_dtos = list( @@ -338,6 +309,9 @@ class ImageService(ImageServiceABC): r, self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name, True), + self._services.board_image_records.get_board_for_image( + r.image_name + ), ), results.items, ) @@ -355,8 +329,8 @@ class ImageService(ImageServiceABC): def delete(self, image_name: str): try: - self._services.files.delete(image_name) - self._services.records.delete(image_name) + self._services.image_files.delete(image_name) + self._services.image_records.delete(image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 1f910253e5..10d1d91920 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -4,7 +4,9 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from logging import Logger - from invokeai.app.services.images import ImageService + from invokeai.app.services.board_images import BoardImagesServiceABC + from invokeai.app.services.boards import BoardServiceABC + from invokeai.app.services.images import ImageServiceABC from invokeai.backend import ModelManager from invokeai.app.services.events import EventServiceBase from invokeai.app.services.latent_storage import LatentsStorageBase @@ -26,9 +28,9 @@ class InvocationServices: model_manager: "ModelManager" restoration: "RestorationServices" configuration: "InvokeAISettings" - images: "ImageService" - - # NOTE: we must forward-declare any types that include invocations, since invocations can use services + images: "ImageServiceABC" + boards: "BoardServiceABC" + board_images: "BoardImagesServiceABC" graph_library: "ItemStorageABC"["LibraryGraph"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] processor: "InvocationProcessorABC" @@ -39,7 +41,9 @@ class InvocationServices: events: "EventServiceBase", logger: "Logger", latents: "LatentsStorageBase", - images: "ImageService", + images: "ImageServiceABC", + boards: "BoardServiceABC", + board_images: "BoardImagesServiceABC", queue: "InvocationQueueABC", graph_library: "ItemStorageABC"["LibraryGraph"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], @@ -52,9 +56,12 @@ class InvocationServices: self.logger = logger self.latents = latents self.images = images + self.boards = boards + self.board_images = board_images self.queue = queue self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor self.restoration = restoration self.configuration = configuration + self.boards = boards diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index c212ff6a72..8b46b17ad0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -5,7 +5,7 @@ from __future__ import annotations import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING +from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass from invokeai.backend.model_management.model_manager import ( @@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC): ) -> bool: pass - @abstractmethod - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name and typeof the default model, or None - if none is defined. - """ - pass - - @abstractmethod - def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType): - """Sets the default model to the indicated name.""" - pass - @abstractmethod def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ @@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase): model_type, ) - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name of the default model, or None - if none is defined. - """ - return self.mgr.default_model() - - def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType): - """Sets the default model to the indicated name.""" - self.mgr.set_default_model(model_name, base_model, model_type) - def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. @@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> dict: + ) -> list[dict]: + # ) -> dict: """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } + Return a list of models. """ return self.mgr.list_models(base_model, model_type) diff --git a/invokeai/app/services/models/board_record.py b/invokeai/app/services/models/board_record.py new file mode 100644 index 0000000000..bf5401b209 --- /dev/null +++ b/invokeai/app/services/models/board_record.py @@ -0,0 +1,62 @@ +from typing import Optional, Union +from datetime import datetime +from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr +from invokeai.app.util.misc import get_iso_timestamp + + +class BoardRecord(BaseModel): + """Deserialized board record.""" + + board_id: str = Field(description="The unique ID of the board.") + """The unique ID of the board.""" + board_name: str = Field(description="The name of the board.") + """The name of the board.""" + created_at: Union[datetime, str] = Field( + description="The created timestamp of the board." + ) + """The created timestamp of the image.""" + updated_at: Union[datetime, str] = Field( + description="The updated timestamp of the board." + ) + """The updated timestamp of the image.""" + deleted_at: Union[datetime, str, None] = Field( + description="The deleted timestamp of the board." + ) + """The updated timestamp of the image.""" + cover_image_name: Optional[str] = Field( + description="The name of the cover image of the board." + ) + """The name of the cover image of the board.""" + + +class BoardDTO(BoardRecord): + """Deserialized board record with cover image URL and image count.""" + + cover_image_name: Optional[str] = Field( + description="The name of the board's cover image." + ) + """The URL of the thumbnail of the most recent image in the board.""" + image_count: int = Field(description="The number of images in the board.") + """The number of images in the board.""" + + +def deserialize_board_record(board_dict: dict) -> BoardRecord: + """Deserializes a board record.""" + + # Retrieve all the values, setting "reasonable" defaults if they are not present. + + board_id = board_dict.get("board_id", "unknown") + board_name = board_dict.get("board_name", "unknown") + cover_image_name = board_dict.get("cover_image_name", "unknown") + created_at = board_dict.get("created_at", get_iso_timestamp()) + updated_at = board_dict.get("updated_at", get_iso_timestamp()) + deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) + + return BoardRecord( + board_id=board_id, + board_name=board_name, + cover_image_name=cover_image_name, + created_at=created_at, + updated_at=updated_at, + deleted_at=deleted_at, + ) diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index d971d65916..cc02016cf9 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel): class ImageDTO(ImageRecord, ImageUrlsDTO): - """Deserialized image record, enriched for the frontend with URLs.""" + """Deserialized image record, enriched for the frontend.""" + board_id: Union[str, None] = Field( + description="The id of the board the image belongs to, if one exists." + ) + """The id of the board the image belongs to, if one exists.""" pass def image_record_to_dto( - image_record: ImageRecord, image_url: str, thumbnail_url: str + image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None] ) -> ImageDTO: """Converts an image record to an image DTO.""" return ImageDTO( **image_record.dict(), image_url=image_url, thumbnail_url=thumbnail_url, + board_id=board_id, ) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index e56ec2a0d2..4e2c789c07 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -266,6 +266,8 @@ class ModelManager(object): for model_key, model_config in config.items(): model_name, base_model, model_type = self.parse_key(model_key) model_class = MODEL_CLASSES[base_model][model_type] + # alias for config file + model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) # check config version number and update on disk/RAM if necessary @@ -446,38 +448,6 @@ class ModelManager(object): _cache = self.cache, ) - def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]: - """ - Returns the name of the default model, or None - if none is defined. - """ - for model_key, model_config in self.models.items(): - if model_config.default: - return self.parse_key(model_key) - - for model_key, _ in self.models.items(): - return self.parse_key(model_key) - else: - return None # TODO: or redo as (None, None, None) - - def set_default_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - ) -> None: - """ - Set the default model. The change will not take - effect until you call model_manager.commit() - """ - - model_key = self.model_key(model_name, base_model, model_type) - if model_key not in self.models: - raise Exception(f"Unknown model: {model_key}") - - for cur_model_key, config in self.models.items(): - config.default = cur_model_key == model_key - def model_info( self, model_name: str, @@ -504,9 +474,9 @@ class ModelManager(object): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> Dict[str, Dict[str, str]]: + ) -> list[dict]: """ - Return a dict of models, in format [base_model][model_type][model_name] + Return a list of models. Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model @@ -514,7 +484,7 @@ class ModelManager(object): object derived from models.yaml """ - models = dict() + models = [] for model_key in sorted(self.models, key=str.casefold): model_config = self.models[model_key] @@ -524,18 +494,16 @@ class ModelManager(object): if model_type is not None and cur_model_type != model_type: continue - if cur_base_model not in models: - models[cur_base_model] = dict() - if cur_model_type not in models[cur_base_model]: - models[cur_base_model][cur_model_type] = dict() - - models[cur_base_model][cur_model_type][cur_model_name] = dict( + model_dict = dict( **model_config.dict(exclude_defaults=True), + # OpenAPIModelInfoBase name=cur_model_name, base_model=cur_base_model, type=cur_model_type, ) + models.append(model_dict) + return models def print_models(self) -> None: @@ -647,7 +615,9 @@ class ModelManager(object): model_class = MODEL_CLASSES[base_model][model_type] if model_class.save_to_config: # TODO: or exclude_unset better fits here? - data_to_save[model_key] = model_config.dict(exclude_defaults=True) + data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) + # alias for config file + data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format") yaml_str = OmegaConf.to_yaml(data_to_save) config_file_path = conf_file or self.config_path diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 40995498bf..6975d45f93 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -1,3 +1,7 @@ +import inspect +from enum import Enum +from pydantic import BaseModel +from typing import Literal, get_origin from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel @@ -29,10 +33,63 @@ MODEL_CLASSES = { #}, } -def get_all_model_configs(): - configs = set() - for models in MODEL_CLASSES.values(): - for _, model in models.items(): - configs.update(model._get_configs().values()) - configs.discard(None) - return list(configs) # TODO: set, list or tuple +MODEL_CONFIGS = list() +OPENAPI_MODEL_CONFIGS = list() + +class OpenAPIModelInfoBase(BaseModel): + name: str + base_model: BaseModelType + type: ModelType + + +for base_model, models in MODEL_CLASSES.items(): + for model_type, model_class in models.items(): + model_configs = set(model_class._get_configs().values()) + model_configs.discard(None) + MODEL_CONFIGS.extend(model_configs) + + for cfg in model_configs: + model_name, cfg_name = cfg.__qualname__.split('.')[-2:] + openapi_cfg_name = model_name + cfg_name + if openapi_cfg_name in vars(): + continue + + api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict( + __annotations__ = dict( + type=Literal[model_type.value], + ), + )) + + #globals()[openapi_cfg_name] = api_wrapper + vars()[openapi_cfg_name] = api_wrapper + OPENAPI_MODEL_CONFIGS.append(api_wrapper) + +def get_model_config_enums(): + enums = list() + + for model_config in MODEL_CONFIGS: + fields = inspect.get_annotations(model_config) + try: + field = fields["model_format"] + except: + raise Exception("format field not found") + + # model_format: None + # model_format: SomeModelFormat + # model_format: Literal[SomeModelFormat.Diffusers] + # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint] + + if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): + enums.append(field) + + elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + enums.append(type(field.__args__[0])) + + elif field is None: + pass + + else: + raise Exception(f"Unsupported format definition in {model_configs.__qualname__}") + + return enums + diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index f18099b4e7..ef354ecc07 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -48,12 +48,10 @@ class ModelError(str, Enum): class ModelConfigBase(BaseModel): path: str # or Path - #name: str # not included as present in model key description: Optional[str] = Field(None) - format: Optional[str] = Field(None) - default: Optional[bool] = Field(False) + model_format: Optional[str] = Field(None) # do not save to config - error: Optional[ModelError] = Field(None, exclude=True) + error: Optional[ModelError] = Field(None) class Config: use_enum_values = True @@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta): def _hf_definition_to_type(self, subtypes: List[str]) -> Type: if len(subtypes) < 2: raise Exception("Invalid subfolder definition!") + if all(t is None for t in subtypes): + return None + elif any(t is None for t in subtypes): + raise Exception(f"Unsupported definition: {subtypes}") + if subtypes[0] in ["diffusers", "transformers"]: res_type = sys.modules[subtypes[0]] subtypes = subtypes[1:] @@ -122,46 +125,41 @@ class ModelBase(metaclass=ABCMeta): continue fields = inspect.get_annotations(value) - if "format" not in fields: - raise Exception("Invalid config definition - format field not found") + try: + field = fields["model_format"] + except: + raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})") - format_type = typing.get_origin(fields["format"]) - if format_type not in {None, Literal, Union}: - raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") + if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): + for model_format in field: + configs[model_format.value] = value - if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__): - raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") + elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + for model_format in field.__args__: + configs[model_format.value] = value + + elif field is None: + configs[None] = value - if format_type == Union: - f_fields = fields["format"].__args__ else: - f_fields = (fields["format"],) - - - for field in f_fields: - if field is None: - format_name = None - else: - format_name = field.__args__[0] - - configs[format_name] = value # TODO: error when override(multiple)? - + raise Exception(f"Unsupported format definition in {cls.__qualname__}") cls.__configs = configs return cls.__configs @classmethod def create_config(cls, **kwargs) -> ModelConfigBase: - if "format" not in kwargs: - raise Exception("Field 'format' not found in model config") + if "model_format" not in kwargs: + raise Exception("Field 'model_format' not found in model config") + configs = cls._get_configs() - return configs[kwargs["format"]](**kwargs) + return configs[kwargs["model_format"]](**kwargs) @classmethod def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: return cls.create_config( path=path, - format=cls.detect_format(path), + model_format=cls.detect_format(path), ) @classmethod diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index d75c55010a..9563f87afd 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -14,12 +15,16 @@ from .base import ( classproperty, ) +class ControlNetModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class ControlNetModel(ModelBase): #model_class: Type #model_size: int class Config(ModelConfigBase): - format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: ControlNetModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.ControlNet @@ -69,9 +74,9 @@ class ControlNetModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return ControlNetModelFormat.Diffusers else: - return "checkpoint" + return ControlNetModelFormat.Checkpoint @classmethod def convert_if_required( @@ -81,7 +86,7 @@ class ControlNetModel(ModelBase): config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": - raise NotImlemetedError("Checkpoint controlnet models currently unsupported") + if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers: + raise NotImplementedError("Checkpoint controlnet models currently unsupported") else: return model_path diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index bcf3224ece..59feacde06 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,5 +1,6 @@ import os import torch +from enum import Enum from typing import Optional, Union, Literal from .base import ( ModelBase, @@ -12,11 +13,15 @@ from .base import ( # TODO: naming from ..lora import LoRAModel as LoRAModelRaw +class LoRAModelFormat(str, Enum): + LyCORIS = "lycoris" + Diffusers = "diffusers" + class LoRAModel(ModelBase): #model_size: int class Config(ModelConfigBase): - format: Union[Literal["lycoris"], Literal["diffusers"]] + model_format: LoRAModelFormat # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora @@ -52,9 +57,9 @@ class LoRAModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return LoRAModelFormat.Diffusers else: - return "lycoris" + return LoRAModelFormat.LyCORIS @classmethod def convert_if_required( @@ -64,7 +69,7 @@ class LoRAModel(ModelBase): config: ModelConfigBase, base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) == "diffusers": + if cls.detect_format(model_path) == LoRAModelFormat.Diffusers: # TODO: add diffusers lora when it stabilizes a bit raise NotImplementedError("Diffusers lora not supported") else: diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index bd519c88c8..f169326571 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -1,5 +1,6 @@ import os import json +from enum import Enum from pydantic import Field from pathlib import Path from typing import Literal, Optional, Union @@ -19,16 +20,19 @@ from .base import ( from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf +class StableDiffusion1ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion1Model(DiffusersModel): class DiffusersConfig(ModelConfigBase): - format: Literal["diffusers"] + model_format: Literal[StableDiffusion1ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType class CheckpointConfig(ModelConfigBase): - format: Literal["checkpoint"] + model_format: Literal[StableDiffusion1ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel): def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion1ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion1ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel): return cls.create_config( path=path, - format=model_format, + model_format=model_format, config=ckpt_config_path, variant=variant, @@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion1ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion1ModelFormat.Checkpoint @classmethod def convert_if_required( @@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel): else: return model_path +class StableDiffusion2ModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): - format: Literal["diffusers"] + model_format: Literal[StableDiffusion2ModelFormat.Diffusers] vae: Optional[str] = Field(None) variant: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool class CheckpointConfig(ModelConfigBase): - format: Literal["checkpoint"] + model_format: Literal[StableDiffusion2ModelFormat.Checkpoint] vae: Optional[str] = Field(None) config: Optional[str] = Field(None) variant: ModelVariantType @@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel): def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) ckpt_config_path = kwargs.get("config", None) - if model_format == "checkpoint": + if model_format == StableDiffusion2ModelFormat.Checkpoint: if ckpt_config_path: ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] @@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel): checkpoint = checkpoint.get('state_dict', checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - elif model_format == "diffusers": + elif model_format == StableDiffusion2ModelFormat.Diffusers: unet_config_path = os.path.join(path, "unet", "config.json") if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: @@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel): return cls.create_config( path=path, - format=model_format, + model_format=model_format, config=ckpt_config_path, variant=variant, @@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel): @classmethod def detect_format(cls, model_path: str): if os.path.isdir(model_path): - return "diffusers" + return StableDiffusion2ModelFormat.Diffusers else: - return "checkpoint" + return StableDiffusion2ModelFormat.Checkpoint @classmethod def convert_if_required( @@ -281,8 +288,8 @@ def _convert_ckpt_and_cache( prediction_type = SchedulerPredictionType.Epsilon elif version == BaseModelType.StableDiffusion2: - upcast_attention = config.upcast_attention - prediction_type = config.prediction_type + upcast_attention = model_config.upcast_attention + prediction_type = model_config.prediction_type else: raise Exception(f"Unknown model provided: {version}") diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 66847f53eb..9a032218f0 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase): #model_size: int class Config(ModelConfigBase): - format: None + model_format: None def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.TextualInversion diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 1edb57ccc4..76133b074d 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -1,5 +1,7 @@ import os import torch +import safetensors +from enum import Enum from pathlib import Path from typing import Optional, Union, Literal from .base import ( @@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf +class VaeModelFormat(str, Enum): + Checkpoint = "checkpoint" + Diffusers = "diffusers" + class VaeModel(ModelBase): #vae_class: Type #model_size: int class Config(ModelConfigBase): - format: Union[Literal["checkpoint"], Literal["diffusers"]] + model_format: VaeModelFormat def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Vae @@ -70,9 +76,9 @@ class VaeModel(ModelBase): @classmethod def detect_format(cls, path: str): if os.path.isdir(path): - return "diffusers" + return VaeModelFormat.Diffusers else: - return "checkpoint" + return VaeModelFormat.Checkpoint @classmethod def convert_if_required( @@ -82,7 +88,7 @@ class VaeModel(ModelBase): config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: - if cls.detect_format(model_path) != "diffusers": + if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: return _convert_vae_ckpt_and_cache( weights_path=model_path, output_path=output_path, diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index ddc6dace27..55fcc97745 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { useListModelsQuery } from 'services/apiSlice'; const DEFAULT_CONFIG = {}; @@ -45,6 +47,18 @@ const App = ({ const isApplicationReady = useIsApplicationReady(); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + const { data: controlnetModels } = useListModelsQuery({ + model_type: 'controlnet', + }); + const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' }); + const { data: loraModels } = useListModelsQuery({ model_type: 'lora' }); + const { data: embeddingModels } = useListModelsQuery({ + model_type: 'embedding', + }); + const [loadingOverridden, setLoadingOverridden] = useState(false); const dispatch = useAppDispatch(); @@ -143,6 +157,7 @@ const App = ({ + diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 0537d1de2a..141e62652d 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -21,6 +21,8 @@ import { DeleteImageContext, DeleteImageContextProvider, } from 'app/contexts/DeleteImageContext'; +import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal'; +import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext'; const App = lazy(() => import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -76,11 +78,13 @@ const InvokeAIUI = ({ - + + + diff --git a/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx new file mode 100644 index 0000000000..f5a856d3d8 --- /dev/null +++ b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx @@ -0,0 +1,89 @@ +import { useDisclosure } from '@chakra-ui/react'; +import { PropsWithChildren, createContext, useCallback, useState } from 'react'; +import { ImageDTO } from 'services/api'; +import { useAddImageToBoardMutation } from 'services/apiSlice'; + +export type ImageUsage = { + isInitialImage: boolean; + isCanvasImage: boolean; + isNodesImage: boolean; + isControlNetImage: boolean; +}; + +type AddImageToBoardContextValue = { + /** + * Whether the move image dialog is open. + */ + isOpen: boolean; + /** + * Closes the move image dialog. + */ + onClose: () => void; + /** + * The image pending movement + */ + image?: ImageDTO; + onClickAddToBoard: (image: ImageDTO) => void; + handleAddToBoard: (boardId: string) => void; +}; + +export const AddImageToBoardContext = + createContext({ + isOpen: false, + onClose: () => undefined, + onClickAddToBoard: () => undefined, + handleAddToBoard: () => undefined, + }); + +type Props = PropsWithChildren; + +export const AddImageToBoardContextProvider = (props: Props) => { + const [imageToMove, setImageToMove] = useState(); + const { isOpen, onOpen, onClose } = useDisclosure(); + + const [addImageToBoard, result] = useAddImageToBoardMutation(); + + // Clean up after deleting or dismissing the modal + const closeAndClearImageToDelete = useCallback(() => { + setImageToMove(undefined); + onClose(); + }, [onClose]); + + const onClickAddToBoard = useCallback( + (image?: ImageDTO) => { + if (!image) { + return; + } + setImageToMove(image); + onOpen(); + }, + [setImageToMove, onOpen] + ); + + const handleAddToBoard = useCallback( + (boardId: string) => { + if (imageToMove) { + addImageToBoard({ + board_id: boardId, + image_name: imageToMove.image_name, + }); + closeAndClearImageToDelete(); + } + }, + [addImageToBoard, closeAndClearImageToDelete, imageToMove] + ); + + return ( + + {props.children} + + ); +}; diff --git a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx index 8263b48114..d01298944b 100644 --- a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx +++ b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx @@ -35,25 +35,23 @@ export const selectImageUsage = createSelector( (state: RootState, image_name?: string) => image_name, ], (generation, canvas, nodes, controlNet, image_name) => { - const isInitialImage = generation.initialImage?.image_name === image_name; + const isInitialImage = generation.initialImage?.imageName === image_name; const isCanvasImage = canvas.layerState.objects.some( - (obj) => obj.kind === 'image' && obj.image.image_name === image_name + (obj) => obj.kind === 'image' && obj.imageName === image_name ); const isNodesImage = nodes.nodes.some((node) => { return some( node.data.inputs, - (input) => - input.type === 'image' && input.value?.image_name === image_name + (input) => input.type === 'image' && input.value === image_name ); }); const isControlNetImage = some( controlNet.controlNets, (c) => - c.controlImage?.image_name === image_name || - c.processedControlImage?.image_name === image_name + c.controlImage === image_name || c.processedControlImage === image_name ); const imageUsage: ImageUsage = { diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index 5025ca081a..cb18d48301 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; -import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { omit } from 'lodash-es'; @@ -18,7 +17,6 @@ const serializationDenylist: { gallery: galleryPersistDenylist, generation: generationPersistDenylist, lightbox: lightboxPersistDenylist, - models: modelsPersistDenylist, nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts index c6af5f3612..8f40b0bb59 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts @@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialConfigState } from 'features/system/store/configSlice'; -import { initialModelsState } from 'features/system/store/modelSlice'; import { initialSystemState } from 'features/system/store/systemSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialUIState } from 'features/ui/store/uiSlice'; @@ -21,7 +20,6 @@ const initialStates: { gallery: initialGalleryState, generation: initialGenerationState, lightbox: initialLightboxState, - models: initialModelsState, nodes: initialNodesState, postprocessing: initialPostprocessingState, system: initialSystemState, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 8c073e81d6..cb641d00db 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect'; +import { + addImageAddedToBoardFulfilledListener, + addImageAddedToBoardRejectedListener, +} from './listeners/imageAddedToBoard'; +import { addBoardIdSelectedListener } from './listeners/boardIdSelected'; +import { + addImageRemovedFromBoardFulfilledListener, + addImageRemovedFromBoardRejectedListener, +} from './listeners/imageRemovedFromBoard'; export const listenerMiddleware = createListenerMiddleware(); @@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect< AppDispatch >; +/** + * The RTK listener middleware is a lightweight alternative sagas/observables. + * + * Most side effect logic should live in a listener. + */ + // Image uploaded addImageUploadedFulfilledListener(); addImageUploadedRejectedListener(); @@ -183,3 +198,10 @@ addControlNetAutoProcessListener(); // Update image URLs on connect addUpdateImageUrlsOnConnectListener(); + +// Boards +addImageAddedToBoardFulfilledListener(); +addImageAddedToBoardRejectedListener(); +addImageRemovedFromBoardFulfilledListener(); +addImageRemovedFromBoardRejectedListener(); +addBoardIdSelectedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts new file mode 100644 index 0000000000..eab4389ceb --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -0,0 +1,99 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { boardIdSelected } from 'features/gallery/store/boardSlice'; +import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addBoardIdSelectedListener = () => { + startAppListening({ + actionCreator: boardIdSelected, + effect: (action, { getState, dispatch }) => { + const boardId = action.payload; + + // we need to check if we need to fetch more images + + const state = getState(); + const allImages = selectImagesAll(state); + + if (!boardId) { + // a board was unselected + dispatch(imageSelected(allImages[0]?.image_name)); + return; + } + + const { categories } = state.images; + + const filteredImages = allImages.filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = boardId ? i.board_id === boardId : true; + return isInCategory && isInSelectedBoard; + }); + + // get the board from the cache + const { data: boards } = api.endpoints.listAllBoards.select()(state); + const board = boards?.find((b) => b.board_id === boardId); + + if (!board) { + // can't find the board in cache... + dispatch(imageSelected(allImages[0]?.image_name)); + return; + } + + dispatch(imageSelected(board.cover_image_name)); + + // if we haven't loaded one full page of images from this board, load more + if ( + filteredImages.length < board.image_count && + filteredImages.length < IMAGES_PER_PAGE + ) { + dispatch(receivedPageOfImages({ categories, boardId })); + } + }, + }); +}; + +export const addBoardIdSelected_changeSelectedImage_listener = () => { + startAppListening({ + actionCreator: boardIdSelected, + effect: (action, { getState, dispatch }) => { + const boardId = action.payload; + + const state = getState(); + + // we need to check if we need to fetch more images + + if (!boardId) { + // a board was unselected - we don't need to do anything + return; + } + + const { categories } = state.images; + + const filteredImages = selectImagesAll(state).filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = boardId ? i.board_id === boardId : true; + return isInCategory && isInSelectedBoard; + }); + + // get the board from the cache + const { data: boards } = api.endpoints.listAllBoards.select()(state); + const board = boards?.find((b) => b.board_id === boardId); + if (!board) { + // can't find the board in cache... + return; + } + + // if we haven't loaded one full page of images from this board, load more + if ( + filteredImages.length < board.image_count && + filteredImages.length < IMAGES_PER_PAGE + ) { + dispatch(receivedPageOfImages({ categories, boardId })); + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index ce1b515b84..7ff9a5118c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => { [controlNet.processorNode.id]: { ...controlNet.processorNode, is_intermediate: true, - image: pick(controlNet.controlImage, ['image_name']), + image: { image_name: controlNet.controlImage }, }, }, }; @@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => { dispatch( controlNetProcessedImageChanged({ controlNetId, - processedControlImage, + processedControlImage: processedControlImage.image_name, }) ); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts new file mode 100644 index 0000000000..0f404cab68 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts @@ -0,0 +1,40 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addImageAddedToBoardFulfilledListener = () => { + startAppListening({ + matcher: api.endpoints.addImageToBoard.matchFulfilled, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Image added to board' + ); + + dispatch( + imageMetadataReceived({ + imageName: image_name, + }) + ); + }, + }); +}; + +export const addImageAddedToBoardRejectedListener = () => { + startAppListening({ + matcher: api.endpoints.addImageToBoard.matchRejected, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Problem adding image to board' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts index 85d56d3913..8f01b8d7b8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts @@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => { startAppListening({ actionCreator: imageCategoriesChanged, effect: (action, { getState, dispatch }) => { - const filteredImagesCount = selectFilteredImagesAsArray( - getState() - ).length; + const state = getState(); + const filteredImagesCount = selectFilteredImagesAsArray(state).length; if (!filteredImagesCount) { - dispatch(receivedPageOfImages()); + dispatch( + receivedPageOfImages({ + categories: action.payload, + boardId: state.boards.selectedBoardId, + }) + ); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index 4c0c057242..224aa0d2aa 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -6,15 +6,15 @@ import { clamp } from 'lodash-es'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageRemoved, - selectImagesEntities, selectImagesIds, } from 'features/gallery/store/imagesSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { api } from 'services/apiSlice'; -const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); +const moduleLog = log.child({ namespace: 'image' }); /** * Called when the user requests an image deletion @@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); export const addRequestedImageDeletionListener = () => { startAppListening({ actionCreator: requestedImageDeletion, - effect: (action, { dispatch, getState }) => { + effect: async (action, { dispatch, getState, condition }) => { const { image, imageUsage } = action.payload; const { image_name } = image; @@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => { const state = getState(); const selectedImage = state.gallery.selectedImage; - if (selectedImage && selectedImage.image_name === image_name) { + if (selectedImage === image_name) { const ids = selectImagesIds(state); - const entities = selectImagesEntities(state); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name @@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => { const newSelectedImageId = filteredIds[newSelectedImageIndex]; - const newSelectedImage = entities[newSelectedImageId]; - if (newSelectedImageId) { - dispatch(imageSelected(newSelectedImage)); + dispatch(imageSelected(newSelectedImageId as string)); } else { dispatch(imageSelected()); } @@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => { dispatch(imageRemoved(image_name)); // Delete from server - dispatch(imageDeleted({ imageName: image_name })); + const { requestId } = dispatch(imageDeleted({ imageName: image_name })); + + // Wait for successful deletion, then trigger boards to re-fetch + const wasImageDeleted = await condition( + (action): action is ReturnType => + imageDeleted.fulfilled.match(action) && + action.meta.requestId === requestId, + 30000 + ); + + if (wasImageDeleted) { + dispatch( + api.util.invalidateTags([{ type: 'Board', id: image.board_id }]) + ); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts new file mode 100644 index 0000000000..40847ade3a --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts @@ -0,0 +1,40 @@ +import { log } from 'app/logging/useLogger'; +import { startAppListening } from '..'; +import { imageMetadataReceived } from 'services/thunks/image'; +import { api } from 'services/apiSlice'; + +const moduleLog = log.child({ namespace: 'boards' }); + +export const addImageRemovedFromBoardFulfilledListener = () => { + startAppListening({ + matcher: api.endpoints.removeImageFromBoard.matchFulfilled, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Image added to board' + ); + + dispatch( + imageMetadataReceived({ + imageName: image_name, + }) + ); + }, + }); +}; + +export const addImageRemovedFromBoardRejectedListener = () => { + startAppListening({ + matcher: api.endpoints.removeImageFromBoard.matchRejected, + effect: (action, { getState, dispatch }) => { + const { board_id, image_name } = action.meta.arg.originalArgs; + + moduleLog.debug( + { data: { board_id, image_name } }, + 'Problem adding image to board' + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 40ed062353..fc44d206c8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => { if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') { const { controlNetId } = postUploadAction; - dispatch(controlNetImageChanged({ controlNetId, controlImage: image })); + dispatch( + controlNetImageChanged({ + controlNetId, + controlImage: image.image_name, + }) + ); return; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 3049d2c933..bf54e63836 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,9 +1,8 @@ -import { startAppListening } from '../..'; import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; import { receivedPageOfImages } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { receivedOpenAPISchema } from 'services/thunks/schema'; +import { startAppListening } from '../..'; const moduleLog = log.child({ namespace: 'socketio' }); @@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { models, nodes, config, images } = getState(); + const { nodes, config, images } = getState(); const { disabledTabs } = config; if (!images.ids.length) { - dispatch(receivedPageOfImages()); - } - - if (!models.ids.length) { - dispatch(receivedModels()); + dispatch( + receivedPageOfImages({ + categories: ['general'], + isIntermediate: false, + }) + ); } if (!nodes.schema && !disabledTabs.includes('nodes')) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index c9ab894ddb..c204f0bdfb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image'; import { sessionCanceled } from 'services/thunks/session'; import { isImageOutput } from 'services/types/guards'; import { progressImageSet } from 'features/system/store/systemSlice'; +import { api } from 'services/apiSlice'; const moduleLog = log.child({ namespace: 'socketio' }); const nodeDenylist = ['dataURL_image']; @@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => { const sessionId = action.payload.data.graph_execution_state_id; - const { cancelType, isCancelScheduled } = getState().system; + const { cancelType, isCancelScheduled, boardIdToAddTo } = + getState().system; // Handle scheduled cancelation if (cancelType === 'scheduled' && isCancelScheduled) { @@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => { dispatch(addImageToStagingArea(imageDTO)); } + if (boardIdToAddTo && !imageDTO.is_intermediate) { + dispatch( + api.endpoints.addImageToBoard.initiate({ + board_id: boardIdToAddTo, + image_name, + }) + ); + } + dispatch(progressImageSet(null)); } // pass along the socket event as an application action diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts index 7cb8012848..b9ddcea4c3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts @@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector( selectImagesEntities, ], (generation, canvas, nodes, controlNet, imageEntities) => { - const allUsedImages: ImageDTO[] = []; + const allUsedImages: string[] = []; if (generation.initialImage) { - allUsedImages.push(generation.initialImage); + allUsedImages.push(generation.initialImage.imageName); } canvas.layerState.objects.forEach((obj) => { if (obj.kind === 'image') { - allUsedImages.push(obj.image); + allUsedImages.push(obj.imageName); } }); @@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector( forEach(imageEntities, (image) => { if (image) { - allUsedImages.push(image); + allUsedImages.push(image.image_name); } }); @@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => { `Fetching new image URLs for ${allUsedImages.length} images` ); - allUsedImages.forEach(({ image_name }) => { + allUsedImages.forEach((image_name) => { dispatch( imageUrlsReceived({ imageName: image_name, diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index f577b73895..57a97168a3 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -5,40 +5,39 @@ import { configureStore, } from '@reduxjs/toolkit'; -import { rememberReducer, rememberEnhancer } from 'redux-remember'; import dynamicMiddlewares from 'redux-dynamic-middlewares'; +import { rememberEnhancer, rememberReducer } from 'redux-remember'; import canvasReducer from 'features/canvas/store/canvasSlice'; +import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; import imagesReducer from 'features/gallery/store/imagesSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import generationReducer from 'features/parameters/store/generationSlice'; -import controlNetReducer from 'features/controlNet/store/controlNetSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import systemReducer from 'features/system/store/systemSlice'; // import sessionReducer from 'features/system/store/sessionSlice'; -import configReducer from 'features/system/store/configSlice'; -import uiReducer from 'features/ui/store/uiSlice'; -import hotkeysReducer from 'features/ui/store/hotkeysSlice'; -import modelsReducer from 'features/system/store/modelSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; +import boardsReducer from 'features/gallery/store/boardSlice'; +import configReducer from 'features/system/store/configSlice'; +import hotkeysReducer from 'features/ui/store/hotkeysSlice'; +import uiReducer from 'features/ui/store/uiSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; - +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { LOCALSTORAGE_PREFIX } from './constants'; +import { api } from 'services/apiSlice'; const allReducers = { canvas: canvasReducer, gallery: galleryReducer, generation: generationReducer, lightbox: lightboxReducer, - models: modelsReducer, nodes: nodesReducer, postprocessing: postprocessingReducer, system: systemReducer, @@ -47,7 +46,9 @@ const allReducers = { hotkeys: hotkeysReducer, images: imagesReducer, controlNet: controlNetReducer, + boards: boardsReducer, // session: sessionReducer, + [api.reducerPath]: api.reducer, }; const rootReducer = combineReducers(allReducers); @@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'gallery', 'generation', 'lightbox', - // 'models', 'nodes', 'postprocessing', 'system', 'ui', 'controlNet', + // 'boards', // 'hotkeys', // 'config', ]; @@ -84,6 +85,7 @@ export const store = configureStore({ immutableCheck: false, serializableCheck: false, }) + .concat(api.middleware) .concat(dynamicMiddlewares) .prepend(listenerMiddleware.middleware), devTools: { diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 669a68c88a..e54b4a8872 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -9,7 +9,7 @@ import { import { useDraggable, useDroppable } from '@dnd-kit/core'; import { useCombinedRefs } from '@dnd-kit/utilities'; import IAIIconButton from 'common/components/IAIIconButton'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { AnimatePresence } from 'framer-motion'; import { ReactElement, SyntheticEvent, useCallback } from 'react'; @@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => { isDropDisabled = false, isDragDisabled = false, isUploadDisabled = false, - fallback = , + fallback = , payloadImage, minSize = 24, postUploadAction, diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index 3d34fbca9e..03a00d5b1c 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -1,10 +1,20 @@ -import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react'; +import { + As, + Flex, + FlexProps, + Icon, + IconProps, + Spinner, + SpinnerProps, +} from '@chakra-ui/react'; +import { ReactElement } from 'react'; +import { FaImage } from 'react-icons/fa'; type Props = FlexProps & { spinnerProps?: SpinnerProps; }; -export const IAIImageFallback = (props: Props) => { +export const IAIImageLoadingFallback = (props: Props) => { const { spinnerProps, ...rest } = props; const { sx, ...restFlexProps } = rest; return ( @@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => { ); }; + +type IAINoImageFallbackProps = { + flexProps?: FlexProps; + iconProps?: IconProps; + as?: As; +}; + +export const IAINoImageFallback = (props: IAINoImageFallbackProps) => { + const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} }; + const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} }; + return ( + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx index b8757eff0c..c3132f0285 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx @@ -1,14 +1,21 @@ -import { Image } from 'react-konva'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { Image, Rect } from 'react-konva'; +import { useGetImageDTOQuery } from 'services/apiSlice'; import useImage from 'use-image'; +import { CanvasImage } from '../store/canvasTypes'; type IAICanvasImageProps = { - url: string; - x: number; - y: number; + canvasImage: CanvasImage; }; const IAICanvasImage = (props: IAICanvasImageProps) => { - const { url, x, y } = props; - const [image] = useImage(url, 'anonymous'); + const { width, height, x, y, imageName } = props.canvasImage; + const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); + const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous'); + + if (!imageDTO) { + return ; + } + return ; }; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx index ea04aa95c8..ec1e87cca7 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx @@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => { {objects.map((obj, i) => { if (isCanvasBaseImage(obj)) { - return ( - - ); + return ; } else if (isCanvasBaseLine(obj)) { const line = ( { return ( {shouldShowStagingImage && currentStagingAreaImage && ( - + )} {shouldShowStagingOutline && ( diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index b7092bf7e0..3e40c1211d 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -203,7 +203,7 @@ export const canvasSlice = createSlice({ y: 0, width: width, height: height, - image: image, + imageName: image.image_name, }, ], }; @@ -325,7 +325,7 @@ export const canvasSlice = createSlice({ kind: 'image', layer: 'base', ...state.layerState.stagingArea.boundingBox, - image, + imageName: image.image_name, }); state.layerState.stagingArea.selectedImageIndex = @@ -865,25 +865,25 @@ export const canvasSlice = createSlice({ state.doesCanvasNeedScaling = true; }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - state.layerState.objects.forEach((object) => { - if (object.kind === 'image') { - if (object.image.image_name === image_name) { - object.image.image_url = image_url; - object.image.thumbnail_url = thumbnail_url; - } - } - }); + // state.layerState.objects.forEach((object) => { + // if (object.kind === 'image') { + // if (object.image.image_name === image_name) { + // object.image.image_url = image_url; + // object.image.thumbnail_url = thumbnail_url; + // } + // } + // }); - state.layerState.stagingArea.images.forEach((stagedImage) => { - if (stagedImage.image.image_name === image_name) { - stagedImage.image.image_url = image_url; - stagedImage.image.thumbnail_url = thumbnail_url; - } - }); - }); + // state.layerState.stagingArea.images.forEach((stagedImage) => { + // if (stagedImage.image.image_name === image_name) { + // stagedImage.image.image_url = image_url; + // stagedImage.image.thumbnail_url = thumbnail_url; + // } + // }); + // }); }, }); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts index ae78287a7b..9294e10d32 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts @@ -38,7 +38,7 @@ export type CanvasImage = { y: number; width: number; height: number; - image: ImageDTO; + imageName: string; }; export type CanvasMaskLine = { diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index b8d8896dad..217caf9461 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage'; import { createSelector } from '@reduxjs/toolkit'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { AnimatePresence, motion } from 'framer-motion'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; import IAIIconButton from 'common/components/IAIIconButton'; import { FaUndo } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const selector = createSelector( controlNetSelector, @@ -31,24 +33,45 @@ type Props = { const ControlNetImagePreview = (props: Props) => { const { imageSx } = props; - const { controlNetId, controlImage, processedControlImage, processorType } = - props.controlNet; + const { + controlNetId, + controlImage: controlImageName, + processedControlImage: processedControlImageName, + processorType, + } = props.controlNet; const dispatch = useAppDispatch(); const { pendingControlImages } = useAppSelector(selector); const [isMouseOverImage, setIsMouseOverImage] = useState(false); + const { + data: controlImage, + isLoading: isLoadingControlImage, + isError: isErrorControlImage, + isSuccess: isSuccessControlImage, + } = useGetImageDTOQuery(controlImageName ?? skipToken); + + const { + data: processedControlImage, + isLoading: isLoadingProcessedControlImage, + isError: isErrorProcessedControlImage, + isSuccess: isSuccessProcessedControlImage, + } = useGetImageDTOQuery(processedControlImageName ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (controlImage?.image_name === droppedImage.image_name) { + if (controlImageName === droppedImage.image_name) { return; } setIsMouseOverImage(false); dispatch( - controlNetImageChanged({ controlNetId, controlImage: droppedImage }) + controlNetImageChanged({ + controlNetId, + controlImage: droppedImage.image_name, + }) ); }, - [controlImage, controlNetId, dispatch] + [controlImageName, controlNetId, dispatch] ); const handleResetControlImage = useCallback(() => { @@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => { h: 'full', }} > - + )} {controlImage && ( diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index f1b62cd997..5a54bdcd74 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -39,8 +39,8 @@ export type ControlNetConfig = { weight: number; beginStepPct: number; endStepPct: number; - controlImage: ImageDTO | null; - processedControlImage: ImageDTO | null; + controlImage: string | null; + processedControlImage: string | null; processorType: ControlNetProcessorType; processorNode: RequiredControlNetProcessorNode; shouldAutoConfig: boolean; @@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({ }, controlNetAddedFromImage: ( state, - action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }> + action: PayloadAction<{ controlNetId: string; controlImage: string }> ) => { const { controlNetId, controlImage } = action.payload; state.controlNets[controlNetId] = { @@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - controlImage: ImageDTO | null; + controlImage: string | null; }> ) => { const { controlNetId, controlImage } = action.payload; @@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({ state, action: PayloadAction<{ controlNetId: string; - processedControlImage: ImageDTO | null; + processedControlImage: string | null; }> ) => { const { controlNetId, processedControlImage } = action.payload; @@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({ // Preemptively remove the image from the gallery const { imageName } = action.meta.arg; forEach(state.controlNets, (c) => { - if (c.controlImage?.image_name === imageName) { + if (c.controlImage === imageName) { c.controlImage = null; c.processedControlImage = null; } - if (c.processedControlImage?.image_name === imageName) { + if (c.processedControlImage === imageName) { c.processedControlImage = null; } }); }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - forEach(state.controlNets, (c) => { - if (c.controlImage?.image_name === image_name) { - c.controlImage.image_url = image_url; - c.controlImage.thumbnail_url = thumbnail_url; - } - if (c.processedControlImage?.image_name === image_name) { - c.processedControlImage.image_url = image_url; - c.processedControlImage.thumbnail_url = thumbnail_url; - } - }); - }); + // forEach(state.controlNets, (c) => { + // if (c.controlImage?.image_name === image_name) { + // c.controlImage.image_url = image_url; + // c.controlImage.thumbnail_url = thumbnail_url; + // } + // if (c.processedControlImage?.image_name === image_name) { + // c.processedControlImage.image_url = image_url; + // c.processedControlImage.thumbnail_url = thumbnail_url; + // } + // }); + // }); builder.addCase(appSocketInvocationError, (state, action) => { state.pendingControlImages = []; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx new file mode 100644 index 0000000000..632cebcb33 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx @@ -0,0 +1,27 @@ +import IAIButton from 'common/components/IAIButton'; +import { useCallback } from 'react'; +import { useCreateBoardMutation } from 'services/apiSlice'; + +const DEFAULT_BOARD_NAME = 'My Board'; + +const AddBoardButton = () => { + const [createBoard, { isLoading }] = useCreateBoardMutation(); + + const handleCreateBoard = useCallback(() => { + createBoard(DEFAULT_BOARD_NAME); + }, [createBoard]); + + return ( + + Add Board + + ); +}; + +export default AddBoardButton; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx new file mode 100644 index 0000000000..e506c88e2d --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx @@ -0,0 +1,93 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { FaImages } from 'react-icons/fa'; +import { boardIdSelected } from '../../store/boardSlice'; +import { useDispatch } from 'react-redux'; +import { IAINoImageFallback } from 'common/components/IAIImageFallback'; +import { AnimatePresence } from 'framer-motion'; +import { SelectedItemOverlay } from '../SelectedItemOverlay'; +import { useCallback } from 'react'; +import { ImageDTO } from 'services/api'; +import { useRemoveImageFromBoardMutation } from 'services/apiSlice'; +import { useDroppable } from '@dnd-kit/core'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; + +const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => { + const dispatch = useDispatch(); + + const handleAllImagesBoardClick = () => { + dispatch(boardIdSelected()); + }; + + const [removeImageFromBoard, { isLoading }] = + useRemoveImageFromBoardMutation(); + + const handleDrop = useCallback( + (droppedImage: ImageDTO) => { + if (!droppedImage.board_id) { + return; + } + removeImageFromBoard({ + board_id: droppedImage.board_id, + image_name: droppedImage.image_name, + }); + }, + [removeImageFromBoard] + ); + + const { + isOver, + setNodeRef, + active: isDropActive, + } = useDroppable({ + id: `board_droppable_all_images`, + data: { + handleDrop, + }, + }); + + return ( + + + + + {isSelected && } + + + {isDropActive && } + + + + All Images + + + ); +}; + +export default AllImagesBoard; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx new file mode 100644 index 0000000000..738693a278 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx @@ -0,0 +1,134 @@ +import { + Collapse, + Flex, + Grid, + IconButton, + Input, + InputGroup, + InputRightElement, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { + boardsSelector, + setBoardSearchText, +} from 'features/gallery/store/boardSlice'; +import { memo, useState } from 'react'; +import HoverableBoard from './HoverableBoard'; +import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; +import AddBoardButton from './AddBoardButton'; +import AllImagesBoard from './AllImagesBoard'; +import { CloseIcon } from '@chakra-ui/icons'; +import { useListAllBoardsQuery } from 'services/apiSlice'; + +const selector = createSelector( + [boardsSelector], + (boardsState) => { + const { selectedBoardId, searchText } = boardsState; + return { selectedBoardId, searchText }; + }, + defaultSelectorOptions +); + +type Props = { + isOpen: boolean; +}; + +const BoardsList = (props: Props) => { + const { isOpen } = props; + const dispatch = useAppDispatch(); + const { selectedBoardId, searchText } = useAppSelector(selector); + + const { data: boards } = useListAllBoardsQuery(); + + const filteredBoards = searchText + ? boards?.filter((board) => + board.board_name.toLowerCase().includes(searchText.toLowerCase()) + ) + : boards; + + const [searchMode, setSearchMode] = useState(false); + + const handleBoardSearch = (searchTerm: string) => { + setSearchMode(searchTerm.length > 0); + dispatch(setBoardSearchText(searchTerm)); + }; + const clearBoardSearch = () => { + setSearchMode(false); + dispatch(setBoardSearchText('')); + }; + + return ( + + + + + { + handleBoardSearch(e.target.value); + }} + /> + {searchText && searchText.length && ( + + } + /> + + )} + + + + + + {!searchMode && } + {filteredBoards && + filteredBoards.map((board) => ( + + ))} + + + + + ); +}; + +export default memo(BoardsList); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx new file mode 100644 index 0000000000..a2c07e4870 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx @@ -0,0 +1,193 @@ +import { + Badge, + Box, + Editable, + EditableInput, + EditablePreview, + Flex, + Image, + MenuItem, + MenuList, +} from '@chakra-ui/react'; + +import { useAppDispatch } from 'app/store/storeHooks'; +import { memo, useCallback } from 'react'; +import { FaFolder, FaTrash } from 'react-icons/fa'; +import { ContextMenu } from 'chakra-ui-contextmenu'; +import { BoardDTO, ImageDTO } from 'services/api'; +import { IAINoImageFallback } from 'common/components/IAIImageFallback'; +import { boardIdSelected } from 'features/gallery/store/boardSlice'; +import { + useAddImageToBoardMutation, + useDeleteBoardMutation, + useGetImageDTOQuery, + useUpdateBoardMutation, +} from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; +import { useDroppable } from '@dnd-kit/core'; +import { AnimatePresence } from 'framer-motion'; +import IAIDropOverlay from 'common/components/IAIDropOverlay'; +import { SelectedItemOverlay } from '../SelectedItemOverlay'; + +interface HoverableBoardProps { + board: BoardDTO; + isSelected: boolean; +} + +const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => { + const dispatch = useAppDispatch(); + + const { data: coverImage } = useGetImageDTOQuery( + board.cover_image_name ?? skipToken + ); + + const { board_name, board_id } = board; + + const handleSelectBoard = useCallback(() => { + dispatch(boardIdSelected(board_id)); + }, [board_id, dispatch]); + + const [updateBoard, { isLoading: isUpdateBoardLoading }] = + useUpdateBoardMutation(); + + const [deleteBoard, { isLoading: isDeleteBoardLoading }] = + useDeleteBoardMutation(); + + const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] = + useAddImageToBoardMutation(); + + const handleUpdateBoardName = (newBoardName: string) => { + updateBoard({ board_id, changes: { board_name: newBoardName } }); + }; + + const handleDeleteBoard = useCallback(() => { + deleteBoard(board_id); + }, [board_id, deleteBoard]); + + const handleDrop = useCallback( + (droppedImage: ImageDTO) => { + if (droppedImage.board_id === board_id) { + return; + } + addImageToBoard({ board_id, image_name: droppedImage.image_name }); + }, + [addImageToBoard, board_id] + ); + + const { + isOver, + setNodeRef, + active: isDropActive, + } = useDroppable({ + id: `board_droppable_${board_id}`, + data: { + handleDrop, + }, + }); + + return ( + + + menuProps={{ size: 'sm', isLazy: true }} + renderMenu={() => ( + + } + onClickCapture={handleDeleteBoard} + > + Delete Board + + + )} + > + {(ref) => ( + + + {board.cover_image_name && coverImage?.image_url && ( + + )} + {!(board.cover_image_name && coverImage?.image_url) && ( + + )} + + {board.image_count} + + + {isSelected && } + + + {isDropActive && } + + + + + { + handleUpdateBoardName(nextValue); + }} + > + + + + + + )} + + + ); +}); + +HoverableBoard.displayName = 'HoverableBoard'; + +export default HoverableBoard; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx new file mode 100644 index 0000000000..b16bddd6b4 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx @@ -0,0 +1,93 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogOverlay, + Box, + Flex, + Spinner, + Text, +} from '@chakra-ui/react'; +import IAIButton from 'common/components/IAIButton'; + +import { memo, useContext, useRef, useState } from 'react'; +import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useListAllBoardsQuery } from 'services/apiSlice'; + +const UpdateImageBoardModal = () => { + // const boards = useSelector(selectBoardsAll); + const { data: boards, isFetching } = useListAllBoardsQuery(); + const { isOpen, onClose, handleAddToBoard, image } = useContext( + AddImageToBoardContext + ); + const [selectedBoard, setSelectedBoard] = useState(); + + const cancelRef = useRef(null); + + const currentBoard = boards?.find( + (board) => board.board_id === image?.board_id + ); + + return ( + + + + + {currentBoard ? 'Move Image to Board' : 'Add Image to Board'} + + + + + + {currentBoard && ( + + Moving this image from{' '} + {currentBoard.board_name} to + + )} + {isFetching ? ( + + ) : ( + setSelectedBoard(v)} + value={selectedBoard} + data={(boards ?? []).map((board) => ({ + label: board.board_name, + value: board.board_id, + }))} + /> + )} + + + + + Cancel + { + if (selectedBoard) { + handleAddToBoard(selectedBoard); + } + }} + ml={3} + > + {currentBoard ? 'Move' : 'Add'} + + + + + + ); +}; + +export default memo(UpdateImageBoardModal); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx index a5eaeb4c71..169a965be0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx @@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { DeleteImageButton } from './DeleteImageModal'; +import { selectImagesById } from '../store/imagesSlice'; +import { RootState } from 'app/store/store'; const currentImageButtonsSelector = createSelector( [ + (state: RootState) => state, systemSelector, gallerySelector, postprocessingSelector, @@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector( lightboxSelector, activeTabNameSelector, ], - (system, gallery, postprocessing, ui, lightbox, activeTabName) => { + (state, system, gallery, postprocessing, ui, lightbox, activeTabName) => { const { isProcessing, isConnected, @@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector( shouldShowProgressInViewer, } = ui; + const imageDTO = selectImagesById(state, gallery.selectedImage ?? ''); + const { selectedImage } = gallery; return { @@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector( activeTabName, isLightboxOpen, shouldHidePreview, - image: selectedImage, - seed: selectedImage?.metadata?.seed, - prompt: selectedImage?.metadata?.positive_conditioning, - negativePrompt: selectedImage?.metadata?.negative_conditioning, + image: imageDTO, + seed: imageDTO?.metadata?.seed, + prompt: imageDTO?.metadata?.positive_conditioning, + negativePrompt: imageDTO?.metadata?.negative_conditioning, shouldShowProgressInViewer, }; }, diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index c591206a27..5426fee3b1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; import NextPrevImageButtons from './NextPrevImageButtons'; import { memo, useCallback } from 'react'; import { systemSelector } from 'features/system/store/systemSelectors'; -import { configSelector } from '../../system/store/configSelectors'; -import { useAppToaster } from 'app/components/Toaster'; import { imageSelected } from '../store/gallerySlice'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; export const imagesSelector = createSelector( [uiSelector, gallerySelector, systemSelector], @@ -29,7 +29,7 @@ export const imagesSelector = createSelector( return { shouldShowImageDetails, shouldHidePreview, - image: selectedImage, + selectedImage, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, @@ -45,11 +45,23 @@ export const imagesSelector = createSelector( const CurrentImagePreview = () => { const { shouldShowImageDetails, - image, + selectedImage, progressImage, shouldShowProgressInViewer, shouldAntialiasProgressImage, } = useAppSelector(imagesSelector); + + // const image = useAppSelector((state: RootState) => + // selectImagesById(state, selectedImage ?? '') + // ); + + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(selectedImage ?? skipToken); + const dispatch = useAppDispatch(); const handleDrop = useCallback( @@ -57,7 +69,7 @@ const CurrentImagePreview = () => { if (droppedImage.image_name === image?.image_name) { return; } - dispatch(imageSelected(droppedImage)); + dispatch(imageSelected(droppedImage.image_name)); }, [dispatch, image?.image_name] ); @@ -98,14 +110,14 @@ const CurrentImagePreview = () => { }} > } + fallback={} isUploadDisabled={true} /> )} - {shouldShowImageDetails && image && ( + {shouldShowImageDetails && image && selectedImage && ( { )} - {!shouldShowImageDetails && image && ( + {!shouldShowImageDetails && image && selectedImage && ( - prev.image.image_name === next.image.image_name && - prev.isSelected === next.isSelected; - /** * Gallery image component with delete/use all/use seed buttons on hover. */ -const HoverableImage = memo((props: HoverableImageProps) => { +const HoverableImage = (props: HoverableImageProps) => { const dispatch = useAppDispatch(); const { activeTabName, @@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => { const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const { onDelete } = useContext(DeleteImageContext); + const { onClickAddToBoard } = useContext(AddImageToBoardContext); const handleDelete = useCallback(() => { onDelete(image); }, [image, onDelete]); @@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => { }, }); + const [removeFromBoard] = useRemoveImageFromBoardMutation(); + const handleMouseOver = () => setIsHovered(true); const handleMouseOut = () => setIsHovered(false); const handleSelectImage = useCallback(() => { - dispatch(imageSelected(image)); + dispatch(imageSelected(image.image_name)); }, [image, dispatch]); // Recall parameters handlers @@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => { // dispatch(setIsLightboxOpen(true)); }; + const handleAddToBoard = useCallback(() => { + onClickAddToBoard(image); + }, [image, onClickAddToBoard]); + + const handleRemoveFromBoard = useCallback(() => { + if (!image.board_id) { + return; + } + removeFromBoard({ board_id: image.board_id, image_name: image.image_name }); + }, [image.board_id, image.image_name, removeFromBoard]); + const handleOpenInNewTab = () => { window.open(image.image_url, '_blank'); }; @@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => { {t('parameters.sendToUnifiedCanvas')} )} + } onClickCapture={handleAddToBoard}> + {image.board_id ? 'Change Board' : 'Add to Board'} + + {image.board_id && ( + } + onClickCapture={handleRemoveFromBoard} + > + Remove from Board + + )} } @@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => { ); -}, memoEqualityCheck); +}; -HoverableImage.displayName = 'HoverableImage'; - -export default HoverableImage; +export default memo(HoverableImage); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index fe8690e379..46f2378ae0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -1,12 +1,15 @@ import { Box, + Button, ButtonGroup, Flex, FlexProps, Grid, Icon, Text, + VStack, forwardRef, + useDisclosure, } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; @@ -20,6 +23,7 @@ import { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, + setGalleryView, } from 'features/gallery/store/gallerySlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; @@ -53,41 +57,51 @@ import { selectImagesAll, } from '../store/imagesSlice'; import { receivedPageOfImages } from 'services/thunks/image'; +import BoardsList from './Boards/BoardsList'; +import { boardsSelector } from '../store/boardSlice'; +import { ChevronUpIcon } from '@chakra-ui/icons'; +import { useListAllBoardsQuery } from 'services/apiSlice'; -const categorySelector = createSelector( +const itemSelector = createSelector( [(state: RootState) => state], (state) => { - const { images } = state; - const { categories } = images; + const { categories, total: allImagesTotal, isLoading } = state.images; + const { selectedBoardId } = state.boards; const allImages = selectImagesAll(state); - const filteredImages = allImages.filter((i) => - categories.includes(i.image_category) - ); + + const images = allImages.filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = selectedBoardId + ? i.board_id === selectedBoardId + : true; + return isInCategory && isInSelectedBoard; + }); return { - images: filteredImages, - isLoading: images.isLoading, - areMoreImagesAvailable: filteredImages.length < images.total, - categories: images.categories, + images, + allImagesTotal, + isLoading, + categories, + selectedBoardId, }; }, defaultSelectorOptions ); const mainSelector = createSelector( - [gallerySelector, uiSelector], - (gallery, ui) => { + [gallerySelector, uiSelector, boardsSelector], + (gallery, ui, boards) => { const { galleryImageMinimumWidth, galleryImageObjectFit, shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, } = gallery; const { shouldPinGallery } = ui; - return { shouldPinGallery, galleryImageMinimumWidth, @@ -95,6 +109,8 @@ const mainSelector = createSelector( shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, + selectedBoardId: boards.selectedBoardId, }; }, defaultSelectorOptions @@ -126,21 +142,44 @@ const ImageGalleryContent = () => { shouldAutoSwitchToNewImages, shouldUseSingleGalleryColumn, selectedImage, + galleryView, } = useAppSelector(mainSelector); - const { images, areMoreImagesAvailable, isLoading, categories } = - useAppSelector(categorySelector); + const { images, isLoading, allImagesTotal, categories, selectedBoardId } = + useAppSelector(itemSelector); + + const { selectedBoard } = useListAllBoardsQuery(undefined, { + selectFromResult: ({ data }) => ({ + selectedBoard: data?.find((b) => b.board_id === selectedBoardId), + }), + }); + + const filteredImagesTotal = useMemo( + () => selectedBoard?.image_count ?? allImagesTotal, + [allImagesTotal, selectedBoard?.image_count] + ); + + const areMoreAvailable = useMemo(() => { + return images.length < filteredImagesTotal; + }, [images.length, filteredImagesTotal]); const handleLoadMoreImages = useCallback(() => { - dispatch(receivedPageOfImages()); - }, [dispatch]); + dispatch( + receivedPageOfImages({ + categories, + boardId: selectedBoardId, + }) + ); + }, [categories, dispatch, selectedBoardId]); const handleEndReached = useMemo(() => { - if (areMoreImagesAvailable && !isLoading) { + if (areMoreAvailable && !isLoading) { return handleLoadMoreImages; } return undefined; - }, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]); + }, [areMoreAvailable, handleLoadMoreImages, isLoading]); + + const { isOpen: isBoardListOpen, onToggle } = useDisclosure(); const handleChangeGalleryImageMinimumWidth = (v: number) => { dispatch(setGalleryImageMinimumWidth(v)); @@ -172,46 +211,79 @@ const ImageGalleryContent = () => { const handleClickImagesCategory = useCallback(() => { dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); + dispatch(setGalleryView('images')); }, [dispatch]); const handleClickAssetsCategory = useCallback(() => { dispatch(imageCategoriesChanged(ASSETS_CATEGORIES)); + dispatch(setGalleryView('assets')); }, [dispatch]); return ( - - - - + + + } + /> + } + /> + + } - /> - } - /> - - + variant="ghost" + sx={{ + w: 'full', + justifyContent: 'center', + alignItems: 'center', + px: 2, + _hover: { + bg: 'base.800', + }, + }} + > + + {selectedBoard ? selectedBoard.board_name : 'All Images'} + + + { icon={shouldPinGallery ? : } /> - - - {images.length || areMoreImagesAvailable ? ( + + + + + + {images.length || areMoreAvailable ? ( <> {shouldUseSingleGalleryColumn ? ( @@ -280,14 +355,12 @@ const ImageGalleryContent = () => { data={images} endReached={handleEndReached} scrollerRef={(ref) => setScrollerRef(ref)} - itemContent={(index, image) => ( + itemContent={(index, item) => ( )} @@ -302,13 +375,11 @@ const ImageGalleryContent = () => { List: ListContainer, }} scrollerRef={setScroller} - itemContent={(index, image) => ( + itemContent={(index, item) => ( )} /> @@ -316,12 +387,12 @@ const ImageGalleryContent = () => { - {areMoreImagesAvailable + {areMoreAvailable ? t('gallery.loadMore') : t('gallery.allImagesLoaded')} @@ -350,7 +421,7 @@ const ImageGalleryContent = () => { )} - + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx index 892516a3cc..e5cb4cf4a8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx @@ -93,19 +93,11 @@ type ImageMetadataViewerProps = { image: ImageDTO; }; -// TODO: I don't know if this is needed. -const memoEqualityCheck = ( - prev: ImageMetadataViewerProps, - next: ImageMetadataViewerProps -) => prev.image.image_name === next.image.image_name; - -// TODO: Show more interesting information in this component. - /** * Image metadata viewer overlays currently selected image and provides * access to any of its metadata for use in processing. */ -const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { +const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const dispatch = useAppDispatch(); const { recallBothPrompts, @@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { ); -}, memoEqualityCheck); +}; -ImageMetadataViewer.displayName = 'ImageMetadataViewer'; - -export default ImageMetadataViewer; +export default memo(ImageMetadataViewer); diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index 82e7a0d623..b1f06ad433 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector( } const currentImageIndex = filteredImageIds.findIndex( - (i) => i === selectedImage.image_name + (i) => i === selectedImage ); const nextImageIndex = clamp( @@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector( !isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1, nextImage, prevImage, + nextImageId, + prevImageId, }; }, { @@ -84,7 +86,7 @@ const NextPrevImageButtons = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { isOnFirstImage, isOnLastImage, nextImage, prevImage } = + const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } = useAppSelector(nextPrevImageButtonsSelector); const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = @@ -99,19 +101,19 @@ const NextPrevImageButtons = () => { }, []); const handlePrevImage = useCallback(() => { - dispatch(imageSelected(prevImage)); - }, [dispatch, prevImage]); + dispatch(imageSelected(prevImageId)); + }, [dispatch, prevImageId]); const handleNextImage = useCallback(() => { - dispatch(imageSelected(nextImage)); - }, [dispatch, nextImage]); + dispatch(imageSelected(nextImageId)); + }, [dispatch, nextImageId]); useHotkeys( 'left', () => { handlePrevImage(); }, - [prevImage] + [prevImageId] ); useHotkeys( @@ -119,7 +121,7 @@ const NextPrevImageButtons = () => { () => { handleNextImage(); }, - [nextImage] + [nextImageId] ); return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx new file mode 100644 index 0000000000..7038b4b64f --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx @@ -0,0 +1,26 @@ +import { motion } from 'framer-motion'; + +export const SelectedItemOverlay = () => ( + +); diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts new file mode 100644 index 0000000000..3dac2b6e50 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts @@ -0,0 +1,23 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { selectBoardsAll } from './boardSlice'; + +export const boardSelector = (state: RootState) => state.boards.entities; + +export const searchBoardsSelector = createSelector( + (state: RootState) => state, + (state) => { + const { + boards: { searchText }, + } = state; + + if (!searchText) { + // If no search text provided, return all entities + return selectBoardsAll(state); + } + + return selectBoardsAll(state).filter((i) => + i.board_name.toLowerCase().includes(searchText.toLowerCase()) + ); + } +); diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts new file mode 100644 index 0000000000..8fc9bfa486 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts @@ -0,0 +1,47 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { api } from 'services/apiSlice'; + +type BoardsState = { + searchText: string; + selectedBoardId?: string; + updateBoardModalOpen: boolean; +}; + +export const initialBoardsState: BoardsState = { + updateBoardModalOpen: false, + searchText: '', +}; + +const boardsSlice = createSlice({ + name: 'boards', + initialState: initialBoardsState, + reducers: { + boardIdSelected: (state, action: PayloadAction) => { + state.selectedBoardId = action.payload; + }, + setBoardSearchText: (state, action: PayloadAction) => { + state.searchText = action.payload; + }, + setUpdateBoardModalOpen: (state, action: PayloadAction) => { + state.updateBoardModalOpen = action.payload; + }, + }, + extraReducers: (builder) => { + builder.addMatcher( + api.endpoints.deleteBoard.matchFulfilled, + (state, action) => { + if (action.meta.arg.originalArgs === state.selectedBoardId) { + state.selectedBoardId = undefined; + } + } + ); + }, +}); + +export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } = + boardsSlice.actions; + +export const boardsSelector = (state: RootState) => state.boards; + +export default boardsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index 4f250a7c3a..b7fc0809a6 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -1,17 +1,16 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; -import { ImageDTO } from 'services/api'; import { imageUpserted } from './imagesSlice'; -import { imageUrlsReceived } from 'services/thunks/image'; type GalleryImageObjectFitType = 'contain' | 'cover'; export interface GalleryState { - selectedImage?: ImageDTO; + selectedImage?: string; galleryImageMinimumWidth: number; galleryImageObjectFit: GalleryImageObjectFitType; shouldAutoSwitchToNewImages: boolean; shouldUseSingleGalleryColumn: boolean; + galleryView: 'images' | 'assets' | 'boards'; } export const initialGalleryState: GalleryState = { @@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = { galleryImageObjectFit: 'cover', shouldAutoSwitchToNewImages: true, shouldUseSingleGalleryColumn: false, + galleryView: 'images', }; export const gallerySlice = createSlice({ name: 'gallery', initialState: initialGalleryState, reducers: { - imageSelected: (state, action: PayloadAction) => { + imageSelected: (state, action: PayloadAction) => { state.selectedImage = action.payload; // TODO: if the user selects an image, disable the auto switch? // state.shouldAutoSwitchToNewImages = false; @@ -48,6 +48,12 @@ export const gallerySlice = createSlice({ ) => { state.shouldUseSingleGalleryColumn = action.payload; }, + setGalleryView: ( + state, + action: PayloadAction<'images' | 'assets' | 'boards'> + ) => { + state.galleryView = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(imageUpserted, (state, action) => { @@ -55,17 +61,17 @@ export const gallerySlice = createSlice({ state.shouldAutoSwitchToNewImages && action.payload.image_category === 'general' ) { - state.selectedImage = action.payload; + state.selectedImage = action.payload.image_name; } }); - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; + // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { + // const { image_name, image_url, thumbnail_url } = action.payload; - if (state.selectedImage?.image_name === image_name) { - state.selectedImage.image_url = image_url; - state.selectedImage.thumbnail_url = thumbnail_url; - } - }); + // if (state.selectedImage?.image_name === image_name) { + // state.selectedImage.image_url = image_url; + // state.selectedImage.thumbnail_url = thumbnail_url; + // } + // }); }, }); @@ -75,6 +81,7 @@ export const { setGalleryImageObjectFit, setShouldAutoSwitchToNewImages, setShouldUseSingleGalleryColumn, + setGalleryView, } = gallerySlice.actions; export default gallerySlice.reducer; diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts index 9c18380c54..25a3341532 100644 --- a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts @@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator'; import { keyBy } from 'lodash-es'; import { imageDeleted, - imageMetadataReceived, imageUrlsReceived, receivedPageOfImages, } from 'services/thunks/image'; @@ -74,11 +73,21 @@ const imagesSlice = createSlice({ }); builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { state.isLoading = false; + const { boardId, categories, imageOrigin, isIntermediate } = + action.meta.arg; + const { items, offset, limit, total } = action.payload; + imagesAdapter.upsertMany(state, items); + + if (!categories?.includes('general') || boardId) { + // need to skip updating the total images count if the images recieved were for a specific board + // TODO: this doesn't work when on the Asset tab/category... + return; + } + state.offset = offset; state.limit = limit; state.total = total; - imagesAdapter.upsertMany(state, items); }); builder.addCase(imageDeleted.pending, (state, action) => { // Image deleted @@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector( .map((i) => i.image_name); } ); + +// export const selectImageById = createSelector( +// (state: RootState, imageId) => state, +// (state) => { +// const { +// images: { categories }, +// } = state; + +// return selectImagesAll(state) +// .filter((i) => categories.includes(i.image_category)) +// .map((i) => i.image_name); +// } +// ); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index dc4590e6ca..c5a3a1970b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -11,6 +11,8 @@ import { FieldComponentProps } from './types'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; import { Flex } from '@chakra-ui/react'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const ImageInputFieldComponent = ( props: FieldComponentProps @@ -19,9 +21,16 @@ const ImageInputFieldComponent = ( const dispatch = useAppDispatch(); + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(field.value ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (field.value?.image_name === droppedImage.image_name) { + if (field.value === droppedImage.image_name) { return; } @@ -29,11 +38,11 @@ const ImageInputFieldComponent = ( fieldValueChanged({ nodeId, fieldName: field.name, - value: droppedImage, + value: droppedImage.image_name, }) ); }, - [dispatch, field.name, field.value?.image_name, nodeId] + [dispatch, field.name, field.value, nodeId] ); const handleReset = useCallback(() => { @@ -56,7 +65,7 @@ const ImageInputFieldComponent = ( }} > { - return { allModelNames }; - // return map(modelList, (_, name) => name); - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +import { memo, useCallback, useEffect, useMemo } from 'react'; +import { FieldComponentProps } from './types'; +import { forEach, isString } from 'lodash-es'; +import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; +import { useTranslation } from 'react-i18next'; +import { useListModelsQuery } from 'services/apiSlice'; const ModelInputFieldComponent = ( props: FieldComponentProps @@ -30,28 +20,82 @@ const ModelInputFieldComponent = ( const { nodeId, field } = props; const dispatch = useAppDispatch(); + const { t } = useTranslation(); - const { allModelNames } = useAppSelector(availableModelsSelector); + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]], + [pipelineModels?.entities, pipelineModels?.ids, field.value] + ); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && pipelineModels?.ids.includes(field.value)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleValueChanged(firstModel); + }, [field.value, handleValueChanged, pipelineModels?.ids]); return ( - + /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5425d1cfd5..341f0c467b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -101,21 +101,6 @@ const nodesSlice = createSlice({ builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { state.schema = action.payload; }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - state.nodes.forEach((node) => { - forEach(node.data.inputs, (input) => { - if (input.type === 'image') { - if (input.value?.image_name === image_name) { - input.value.image_url = image_url; - input.value.thumbnail_url = thumbnail_url; - } - } - }); - }); - }); }, }); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 5e140b6eef..acad10cf48 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & { export type ImageInputFieldValue = FieldValueBase & { type: 'image'; - value?: ImageDTO; + value?: string; }; export type ModelInputFieldValue = FieldValueBase & { diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index dd5a97e2f1..314af85193 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = ( if (processedControlImage && processorType !== 'none') { // We've already processed the image in the app, so we can just use the processed image - const { image_name } = processedControlImage; controlNetNode.image = { - image_name, + image_name: processedControlImage, }; } else if (controlImage) { // The control image is preprocessed - const { image_name } = controlImage; controlNetNode.image = { - image_name, + image_name: controlImage, }; } else { // Skip ControlNets without an unprocessed image - should never happen if everything is working correctly diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index efaeaddff2..ccdc3e0a27 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -23,6 +23,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 785e1d2fdb..9ffe85b3c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -17,6 +17,7 @@ import { INPAINT_GRAPH, INPAINT, } from './constants'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = ( // We may need to set the inpaint width and height to scale the image const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; + const model = modelIdToPipelineModelField(modelId); + const graph: NonNullableGraph = { id: INPAINT_GRAPH, nodes: { @@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = ( prompt: negativePrompt, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [RANGE_OF_SIZE]: { type: 'range_of_size', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index ca0e56e849..920cb5bf02 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -14,6 +14,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * Builds the Canvas tab's Text to Image graph. @@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = ( // The bounding box determines width and height, not the width and height params const { width, height } = state.canvas.boundingBoxDimensions; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 1f2c8327e0..8425ac043a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -22,6 +22,7 @@ import { } from './constants'; import { set } from 'lodash-es'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; const moduleLog = log.child({ namespace: 'nodes' }); @@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = ( throw new Error('No initial image found in state'); } + const model = modelIdToPipelineModelField(modelId); + // copy-pasted graph from node editor, filled in with state values & friendly node ids const graph: NonNullableGraph = { id: IMAGE_TO_IMAGE_GRAPH, @@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = ( id: NOISE, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', @@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = ( id: RESIZE, type: 'img_resize', image: { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }, is_intermediate: true, width, @@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = ( } else { // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly set(graph.nodes[IMAGE_TO_LATENTS], 'image', { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }); // Pass the image's dimensions to the `NOISE` node diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index c179a89504..973acdfb77 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,6 +1,10 @@ import { RootState } from 'app/store/store'; import { NonNullableGraph } from 'features/nodes/types/types'; -import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; +import { + BaseModelType, + RandomIntInvocation, + RangeOfSizeInvocation, +} from 'services/api'; import { ITERATE, LATENTS_TO_IMAGE, @@ -14,6 +18,7 @@ import { TEXT_TO_LATENTS, } from './constants'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; type TextToImageGraphOverrides = { width: number; @@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = ( const { positivePrompt, negativePrompt, - model: model_name, + model: modelId, cfgScale: cfg_scale, scheduler, steps, @@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = ( shouldRandomizeSeed, } = state.generation; + const model = modelIdToPipelineModelField(modelId); + /** * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * full graph here as a template. Then use the parameters from app state and set friendlier node @@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = ( steps, }, [MODEL_LOADER]: { - type: 'sd1_model_loader', + type: 'pipeline_model_loader', id: MODEL_LOADER, - model_name, + model, }, [LATENTS_TO_IMAGE]: { type: 'l2i', diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 6a700d4813..072b1a53fd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,9 +1,10 @@ import { Graph } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; -import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; +import { cloneDeep, omit, reduce } from 'lodash-es'; import { RootState } from 'app/store/store'; import { InputFieldValue } from 'features/nodes/types/types'; import { AnyInvocation } from 'services/events/types'; +import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; /** * We need to do special handling for some fields @@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'model') { + if (field.value) { + return modelIdToPipelineModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 39e0080d11..7d4469bc41 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -7,7 +7,7 @@ export const NOISE = 'noise'; export const RANDOM_INT = 'rand_int'; export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; -export const MODEL_LOADER = 'model_loader'; +export const MODEL_LOADER = 'pipeline_model_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts new file mode 100644 index 0000000000..bbcd8d9bc6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts @@ -0,0 +1,18 @@ +import { BaseModelType, PipelineModelField } from 'services/api'; + +/** + * Crudely converts a model id to a pipeline model field + * TODO: Make better + */ +export const modelIdToPipelineModelField = ( + modelId: string +): PipelineModelField => { + const [base_model, model_type, model_name] = modelId.split('/'); + + const field: PipelineModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts index e29b46af70..6ebd014876 100644 --- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts @@ -57,7 +57,7 @@ export const buildImg2ImgNode = ( } imageToImageNode.image = { - image_name: initialImage.image_name, + image_name: initialImage.imageName, }; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx index 65da89b94d..5092893eed 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx @@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler'; const ParamSchedulerAndModel = () => { return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index fa415074e6..fbfa00e2a1 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; import { ImageDTO } from 'services/api'; -import { IAIImageFallback } from 'common/components/IAIImageFallback'; +import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback'; +import { useGetImageDTOQuery } from 'services/apiSlice'; +import { skipToken } from '@reduxjs/toolkit/dist/query'; const selector = createSelector( [generationSelector], @@ -27,14 +29,21 @@ const InitialImagePreview = () => { const { initialImage } = useAppSelector(selector); const dispatch = useAppDispatch(); + const { + data: image, + isLoading, + isError, + isSuccess, + } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken); + const handleDrop = useCallback( (droppedImage: ImageDTO) => { - if (droppedImage.image_name === initialImage?.image_name) { + if (droppedImage.image_name === initialImage?.imageName) { return; } dispatch(initialImageChanged(droppedImage)); }, - [dispatch, initialImage?.image_name] + [dispatch, initialImage] ); const handleReset = useCallback(() => { @@ -53,10 +62,10 @@ const InitialImagePreview = () => { }} > } + fallback={} postUploadAction={{ type: 'SET_INITIAL_IMAGE' }} withResetIcon /> diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 961ea1b8af..e7dcbf0d83 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -1,10 +1,9 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; +import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; import { configChanged } from 'features/system/store/configSlice'; -import { clamp, sortBy } from 'lodash-es'; +import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api'; -import { imageUrlsReceived } from 'services/thunks/image'; -import { receivedModels } from 'services/thunks/model'; import { CfgScaleParam, HeightParam, @@ -17,14 +16,13 @@ import { StrengthParam, WidthParam, } from './parameterZodSchemas'; -import { DEFAULT_SCHEDULER_NAME } from 'app/constants'; export interface GenerationState { cfgScale: CfgScaleParam; height: HeightParam; img2imgStrength: StrengthParam; infillMethod: string; - initialImage?: ImageDTO; + initialImage?: { imageName: string; width: number; height: number }; iterations: number; perlin: number; positivePrompt: PositivePromptParam; @@ -212,35 +210,20 @@ export const generationSlice = createSlice({ state.shouldUseNoiseSettings = action.payload; }, initialImageChanged: (state, action: PayloadAction) => { - state.initialImage = action.payload; + const { image_name, width, height } = action.payload; + state.initialImage = { imageName: image_name, width, height }; }, modelSelected: (state, action: PayloadAction) => { state.model = action.payload; }, }, extraReducers: (builder) => { - builder.addCase(receivedModels.fulfilled, (state, action) => { - if (!state.model) { - const firstModel = sortBy(action.payload, 'name')[0]; - state.model = firstModel.name; - } - }); - builder.addCase(configChanged, (state, action) => { const defaultModel = action.payload.sd?.defaultModel; if (defaultModel && !state.model) { state.model = defaultModel; } }); - - builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { - const { image_name, image_url, thumbnail_url } = action.payload; - - if (state.initialImage?.image_name === image_name) { - state.initialImage.image_url = image_url; - state.initialImage.thumbnail_url = thumbnail_url; - } - }); }, }); diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts index 61567d3fb8..48eb309e7d 100644 --- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts @@ -154,3 +154,17 @@ export type StrengthParam = z.infer; */ export const isValidStrength = (val: unknown): val is StrengthParam => zStrength.safeParse(val).success; + +// /** +// * Zod schema for BaseModelType +// */ +// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); +// /** +// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. +// */ +// export type BaseModelType = z.infer; +// /** +// * Validates/type-guards a value as a base model type +// */ +// export const isValidBaseModelType = (val: unknown): val is BaseModelType => +// zBaseModelType.safeParse(val).success; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index a38ab150dd..43de14d507 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -1,44 +1,59 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { isEqual } from 'lodash-es'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIMantineSelect, { - IAISelectDataType, -} from 'common/components/IAIMantineSelect'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; +import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { modelSelected } from 'features/parameters/store/generationSlice'; -import { selectModelsAll, selectModelsById } from '../store/modelSlice'; -const selector = createSelector( - [(state: RootState) => state, generationSelector], - (state, generation) => { - const selectedModel = selectModelsById(state, generation.model); +import { forEach, isString } from 'lodash-es'; +import { SelectItem } from '@mantine/core'; +import { RootState } from 'app/store/store'; +import { useListModelsQuery } from 'services/apiSlice'; - const modelData = selectModelsAll(state) - .map((m) => ({ - value: m.name, - label: m.name, - })) - .sort((a, b) => a.label.localeCompare(b.label)); - return { - selectedModel, - modelData, - }; - }, - { - memoizeOptions: { - resultEqualityCheck: isEqual, - }, - } -); +export const MODEL_TYPE_MAP = { + 'sd-1': 'Stable Diffusion 1.x', + 'sd-2': 'Stable Diffusion 2.x', +}; const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { selectedModel, modelData } = useAppSelector(selector); + + const selectedModelId = useAppSelector( + (state: RootState) => state.generation.model + ); + + const { data: pipelineModels } = useListModelsQuery({ + model_type: 'pipeline', + }); + + const data = useMemo(() => { + if (!pipelineModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(pipelineModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: MODEL_TYPE_MAP[model.base_model], + }); + }); + + return data; + }, [pipelineModels]); + + const selectedModel = useMemo( + () => pipelineModels?.entities[selectedModelId], + [pipelineModels?.entities, selectedModelId] + ); + const handleChangeModel = useCallback( (v: string | null) => { if (!v) { @@ -49,13 +64,27 @@ const ModelSelect = () => { [dispatch] ); + useEffect(() => { + if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) { + return; + } + + const firstModel = pipelineModels?.ids[0]; + + if (!isString(firstModel)) { + return; + } + + handleChangeModel(firstModel); + }, [handleChangeModel, pipelineModels?.ids, selectedModelId]); + return ( ); diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 2e0b3234c7..26c11604e1 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,6 +1,5 @@ import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { RootState } from 'app/store/store'; - import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; @@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({ export default function SettingsSchedulers() { const dispatch = useAppDispatch(); + const { t } = useTranslation(); const enabledSchedulers = useAppSelector( diff --git a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts index 193420e29c..8ba5731a5b 100644 --- a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts +++ b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts @@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors'; const isApplicationReadySelector = createSelector( [systemSelector, configSelector], (system, config) => { - const { wereModelsReceived, wasSchemaParsed } = system; + const { wasSchemaParsed } = system; const { disabledTabs } = config; return { disabledTabs, - wereModelsReceived, wasSchemaParsed, }; } @@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector( * Checks if the application is ready to be used, i.e. if the initial startup process is finished. */ export const useIsApplicationReady = () => { - const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( + const { disabledTabs, wasSchemaParsed } = useAppSelector( isApplicationReadySelector ); const isApplicationReady = useMemo(() => { - if (!wereModelsReceived) { - return false; - } - if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { return false; } return true; - }, [disabledTabs, wereModelsReceived, wasSchemaParsed]); + }, [disabledTabs, wasSchemaParsed]); return isApplicationReady; }; diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts deleted file mode 100644 index f857bc85bc..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts +++ /dev/null @@ -1,3 +0,0 @@ -import { RootState } from 'app/store/store'; - -export const modelSelector = (state: RootState) => state.models; diff --git a/invokeai/frontend/web/src/features/system/store/modelSlice.ts b/invokeai/frontend/web/src/features/system/store/modelSlice.ts deleted file mode 100644 index ed38425872..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelSlice.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { createEntityAdapter } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; -import { CkptModelInfo, DiffusersModelInfo } from 'services/api'; -import { receivedModels } from 'services/thunks/model'; - -export type Model = (CkptModelInfo | DiffusersModelInfo) & { - name: string; -}; - -export const modelsAdapter = createEntityAdapter({ - selectId: (model) => model.name, - sortComparer: (a, b) => a.name.localeCompare(b.name), -}); - -export const initialModelsState = modelsAdapter.getInitialState(); - -export type ModelsState = typeof initialModelsState; - -export const modelsSlice = createSlice({ - name: 'models', - initialState: initialModelsState, - reducers: { - modelAdded: modelsAdapter.upsertOne, - }, - extraReducers(builder) { - /** - * Received Models - FULFILLED - */ - builder.addCase(receivedModels.fulfilled, (state, action) => { - const models = action.payload; - modelsAdapter.setAll(state, models); - }); - }, -}); - -export const { - selectAll: selectModelsAll, - selectById: selectModelsById, - selectEntities: selectModelsEntities, - selectIds: selectModelsIds, - selectTotal: selectModelsTotal, -} = modelsAdapter.getSelectors((state) => state.models); - -export const { modelAdded } = modelsSlice.actions; - -export default modelsSlice.reducer; diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts deleted file mode 100644 index aa9fb057e1..0000000000 --- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { ModelsState } from './modelSlice'; - -/** - * Models slice persist denylist - */ -export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids']; diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index b17f497f6c..688f69c1f7 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,20 +1,12 @@ import { UseToastOptions } from '@chakra-ui/react'; -import { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import * as InvokeAI from 'app/types/invokeai'; -import { ProgressImage } from 'services/events/types'; -import { makeToast } from '../../../app/components/Toaster'; -import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; -import { receivedModels } from 'services/thunks/model'; -import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; -import { LogLevelName } from 'roarr'; import { InvokeLogLevel } from 'app/logging/useLogger'; -import { TFuncKey } from 'i18next'; -import { t } from 'i18next'; import { userInvoked } from 'app/store/actions'; -import { LANGUAGES } from '../components/LanguagePicker'; -import { imageUploaded } from 'services/thunks/image'; +import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice'; +import { TFuncKey, t } from 'i18next'; +import { LogLevelName } from 'roarr'; import { appSocketConnected, appSocketDisconnected, @@ -26,6 +18,11 @@ import { appSocketSubscribed, appSocketUnsubscribed, } from 'services/events/actions'; +import { ProgressImage } from 'services/events/types'; +import { imageUploaded } from 'services/thunks/image'; +import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session'; +import { makeToast } from '../../../app/components/Toaster'; +import { LANGUAGES } from '../components/LanguagePicker'; export type CancelStrategy = 'immediate' | 'scheduled'; @@ -95,6 +92,7 @@ export interface SystemState { shouldAntialiasProgressImage: boolean; language: keyof typeof LANGUAGES; isUploading: boolean; + boardIdToAddTo?: string; } export const initialSystemState: SystemState = { @@ -225,6 +223,7 @@ export const systemSlice = createSlice({ */ builder.addCase(appSocketSubscribed, (state, action) => { state.sessionId = action.payload.sessionId; + state.boardIdToAddTo = action.payload.boardId; state.canceledSession = ''; }); @@ -233,6 +232,7 @@ export const systemSlice = createSlice({ */ builder.addCase(appSocketUnsubscribed, (state) => { state.sessionId = null; + state.boardIdToAddTo = undefined; }); /** @@ -376,13 +376,6 @@ export const systemSlice = createSlice({ ); }); - /** - * Received available models from the backend - */ - builder.addCase(receivedModels.fulfilled, (state) => { - state.wereModelsReceived = true; - }); - /** * OpenAPI schema was parsed */ diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 7481a5daad..8ce42494e5 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI'; export type { AddInvocation } from './models/AddInvocation'; export type { BaseModelType } from './models/BaseModelType'; +export type { BoardChanges } from './models/BoardChanges'; +export type { BoardDTO } from './models/BoardDTO'; +export type { Body_create_board_image } from './models/Body_create_board_image'; +export type { Body_remove_board_image } from './models/Body_remove_board_image'; export type { Body_upload_image } from './models/Body_upload_image'; export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation'; export type { CkptModelInfo } from './models/CkptModelInfo'; @@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField'; export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; export type { ControlField } from './models/ControlField'; export type { ControlNetInvocation } from './models/ControlNetInvocation'; +export type { ControlNetModelConfig } from './models/ControlNetModelConfig'; +export type { ControlNetModelFormat } from './models/ControlNetModelFormat'; export type { ControlOutput } from './models/ControlOutput'; export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; @@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation'; export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntOutput } from './models/IntOutput'; -export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config'; export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { LatentsField } from './models/LatentsField'; @@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation'; export type { LoraInfo } from './models/LoraInfo'; export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation'; export type { LoraLoaderOutput } from './models/LoraLoaderOutput'; +export type { LoRAModelConfig } from './models/LoRAModelConfig'; +export type { LoRAModelFormat } from './models/LoRAModelFormat'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskOutput } from './models/MaskOutput'; export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; @@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseOutput } from './models/NoiseOutput'; export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation'; +export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_'; export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_'; export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; +export type { PipelineModelField } from './models/PipelineModelField'; +export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation'; export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptOutput } from './models/PromptOutput'; export type { RandomIntInvocation } from './models/RandomIntInvocation'; @@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { SchedulerPredictionType } from './models/SchedulerPredictionType'; -export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation'; -export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation'; +export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig'; +export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig'; +export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat'; +export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig'; +export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig'; +export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat'; export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation'; export type { SubModelType } from './models/SubModelType'; export type { SubtractInvocation } from './models/SubtractInvocation'; export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; +export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig'; export type { UNetField } from './models/UNetField'; export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { VaeField } from './models/VaeField'; +export type { VaeModelConfig } from './models/VaeModelConfig'; +export type { VaeModelFormat } from './models/VaeModelFormat'; export type { VaeRepo } from './models/VaeRepo'; export type { ValidationError } from './models/ValidationError'; export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; +export { BoardsService } from './services/BoardsService'; export { ImagesService } from './services/ImagesService'; export { ModelsService } from './services/ModelsService'; export { SessionsService } from './services/SessionsService'; diff --git a/invokeai/frontend/web/src/services/api/models/BoardChanges.ts b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts new file mode 100644 index 0000000000..fb2bfa0cd9 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type BoardChanges = { + /** + * The board's new name. + */ + board_name?: string; + /** + * The name of the board's new cover image. + */ + cover_image_name?: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/BoardDTO.ts b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts new file mode 100644 index 0000000000..bbcc6f1dd6 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts @@ -0,0 +1,38 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * Deserialized board record with cover image URL and image count. + */ +export type BoardDTO = { + /** + * The unique ID of the board. + */ + board_id: string; + /** + * The name of the board. + */ + board_name: string; + /** + * The created timestamp of the board. + */ + created_at: string; + /** + * The updated timestamp of the board. + */ + updated_at: string; + /** + * The deleted timestamp of the board. + */ + deleted_at?: string; + /** + * The name of the board's cover image. + */ + cover_image_name?: string; + /** + * The number of images in the board. + */ + image_count: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts new file mode 100644 index 0000000000..47f8537eaa --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type Body_create_board_image = { + /** + * The id of the board to add to + */ + board_id: string; + /** + * The name of the image to add + */ + image_name: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts new file mode 100644 index 0000000000..6f5a3652d0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts @@ -0,0 +1,15 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +export type Body_remove_board_image = { + /** + * The id of the board + */ + board_id: string; + /** + * The name of the image to remove + */ + image_name: string; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts new file mode 100644 index 0000000000..60e2958f5c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ControlNetModelFormat } from './ControlNetModelFormat'; +import type { ModelError } from './ModelError'; + +export type ControlNetModelConfig = { + name: string; + base_model: BaseModelType; + type: 'controlnet'; + path: string; + description?: string; + model_format: ControlNetModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts new file mode 100644 index 0000000000..500b3e8f8c --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type ControlNetModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts index e148954f16..5fba3d8311 100644 --- a/invokeai/frontend/web/src/services/api/models/Graph.ts +++ b/invokeai/frontend/web/src/services/api/models/Graph.ts @@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI import type { ParamFloatInvocation } from './ParamFloatInvocation'; import type { ParamIntInvocation } from './ParamIntInvocation'; import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from './RandomIntInvocation'; import type { RandomRangeInvocation } from './RandomRangeInvocation'; import type { RangeInvocation } from './RangeInvocation'; @@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from './RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from './ShowImageInvocation'; import type { StepParamEasingInvocation } from './StepParamEasingInvocation'; import type { SubtractInvocation } from './SubtractInvocation'; @@ -73,7 +72,7 @@ export type Graph = { /** * The nodes in this graph */ - nodes?: Record; + nodes?: Record; /** * The connections between nodes and their fields in this graph */ diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts index 5399d16b8f..4e273e8854 100644 --- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts +++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts @@ -7,7 +7,7 @@ import type { ImageMetadata } from './ImageMetadata'; import type { ResourceOrigin } from './ResourceOrigin'; /** - * Deserialized image record, enriched for the frontend with URLs. + * Deserialized image record, enriched for the frontend. */ export type ImageDTO = { /** @@ -66,5 +66,9 @@ export type ImageDTO = { * A limited subset of the image's generation metadata. Retrieve the image's session for full metadata. */ metadata?: ImageMetadata; + /** + * The id of the board the image belongs to, if one exists. + */ + board_id?: string; }; diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts new file mode 100644 index 0000000000..184a266169 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { LoRAModelFormat } from './LoRAModelFormat'; +import type { ModelError } from './ModelError'; + +export type LoRAModelConfig = { + name: string; + base_model: BaseModelType; + type: 'lora'; + path: string; + description?: string; + model_format: LoRAModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts new file mode 100644 index 0000000000..829f8a7c57 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type LoRAModelFormat = 'lycoris' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/ModelsList.ts b/invokeai/frontend/web/src/services/api/models/ModelsList.ts index a2d88d1967..9186db3e29 100644 --- a/invokeai/frontend/web/src/services/api/models/ModelsList.ts +++ b/invokeai/frontend/web/src/services/api/models/ModelsList.ts @@ -2,16 +2,16 @@ /* tslint:disable */ /* eslint-disable */ -import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config'; -import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig'; -import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig'; -import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config'; -import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config'; +import type { ControlNetModelConfig } from './ControlNetModelConfig'; +import type { LoRAModelConfig } from './LoRAModelConfig'; +import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig'; +import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig'; +import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig'; +import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig'; +import type { TextualInversionModelConfig } from './TextualInversionModelConfig'; +import type { VaeModelConfig } from './VaeModelConfig'; export type ModelsList = { - models: Record>>; + models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>; }; diff --git a/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts new file mode 100644 index 0000000000..2e4734f469 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts @@ -0,0 +1,28 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BoardDTO } from './BoardDTO'; + +/** + * Offset-paginated results + */ +export type OffsetPaginatedResults_BoardDTO_ = { + /** + * Items + */ + items: Array; + /** + * Offset from which to retrieve items + */ + offset: number; + /** + * Limit of items to get + */ + limit: number; + /** + * Total number of items in result + */ + total: number; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts new file mode 100644 index 0000000000..c2f1c07fbf --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts @@ -0,0 +1,20 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; + +/** + * Pipeline model field + */ +export type PipelineModelField = { + /** + * Name of the model + */ + model_name: string; + /** + * Base model + */ + base_model: BaseModelType; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts similarity index 52% rename from invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts rename to invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts index f477c11a8d..b8cdb27acf 100644 --- a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts +++ b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts @@ -2,10 +2,12 @@ /* tslint:disable */ /* eslint-disable */ +import type { PipelineModelField } from './PipelineModelField'; + /** - * Loading submodels of selected model. + * Loads a pipeline model, outputting its submodels. */ -export type SD2ModelLoaderInvocation = { +export type PipelineModelLoaderInvocation = { /** * The id of this node. Must be unique among all nodes. */ @@ -14,10 +16,10 @@ export type SD2ModelLoaderInvocation = { * Whether or not this node is an intermediate node. */ is_intermediate?: boolean; - type?: 'sd2_model_loader'; + type?: 'pipeline_model_loader'; /** - * Model to load + * The model to load */ - model_name?: string; + model: PipelineModelField; }; diff --git a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts deleted file mode 100644 index 9a8a23077a..0000000000 --- a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -/** - * Loading submodels of selected model. - */ -export type SD1ModelLoaderInvocation = { - /** - * The id of this node. Must be unique among all nodes. - */ - id: string; - /** - * Whether or not this node is an intermediate node. - */ - is_intermediate?: boolean; - type?: 'sd1_model_loader'; - /** - * Model to load - */ - model_name?: string; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts similarity index 60% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts index 6bdcb87dd4..be7077cc53 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts @@ -2,14 +2,17 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = { +export type StableDiffusion1ModelCheckpointConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts similarity index 59% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts index c88e042178..befe014605 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts @@ -2,14 +2,17 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = { +export type StableDiffusion1ModelDiffusersConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts new file mode 100644 index 0000000000..01b50c2fc0 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts similarity index 69% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts index ec2ae4a845..dadd7cac9b 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts @@ -2,15 +2,18 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = { +export type StableDiffusion2ModelCheckpointConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'checkpoint'; - default?: boolean; + model_format: 'checkpoint'; error?: ModelError; vae?: string; config?: string; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts similarity index 68% rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts index 67b897d9d9..1e4a34c5dc 100644 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts @@ -2,15 +2,18 @@ /* tslint:disable */ /* eslint-disable */ +import type { BaseModelType } from './BaseModelType'; import type { ModelError } from './ModelError'; import type { ModelVariantType } from './ModelVariantType'; import type { SchedulerPredictionType } from './SchedulerPredictionType'; -export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = { +export type StableDiffusion2ModelDiffusersConfig = { + name: string; + base_model: BaseModelType; + type: 'pipeline'; path: string; description?: string; - format: 'diffusers'; - default?: boolean; + model_format: 'diffusers'; error?: ModelError; vae?: string; variant: ModelVariantType; diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts new file mode 100644 index 0000000000..7e7b895231 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts new file mode 100644 index 0000000000..97d6aa7ffa --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts @@ -0,0 +1,17 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ModelError } from './ModelError'; + +export type TextualInversionModelConfig = { + name: string; + base_model: BaseModelType; + type: 'embedding'; + path: string; + description?: string; + model_format: null; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts new file mode 100644 index 0000000000..a73ee0aa32 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts @@ -0,0 +1,18 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +import type { BaseModelType } from './BaseModelType'; +import type { ModelError } from './ModelError'; +import type { VaeModelFormat } from './VaeModelFormat'; + +export type VaeModelConfig = { + name: string; + base_model: BaseModelType; + type: 'vae'; + path: string; + description?: string; + model_format: VaeModelFormat; + error?: ModelError; +}; + diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts new file mode 100644 index 0000000000..497f81d16f --- /dev/null +++ b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts @@ -0,0 +1,8 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ + +/** + * An enumeration. + */ +export type VaeModelFormat = 'checkpoint' | 'diffusers'; diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts deleted file mode 100644 index f8decdb341..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts deleted file mode 100644 index 614749a2c5..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__lora__LoRAModel__Config = { - path: string; - description?: string; - format: ('lycoris' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts deleted file mode 100644 index f23d5002e3..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = { - path: string; - description?: string; - format: null; - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts deleted file mode 100644 index d9314a6063..0000000000 --- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts +++ /dev/null @@ -1,14 +0,0 @@ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ - -import type { ModelError } from './ModelError'; - -export type invokeai__backend__model_management__models__vae__VaeModel__Config = { - path: string; - description?: string; - format: ('checkpoint' | 'diffusers'); - default?: boolean; - error?: ModelError; -}; - diff --git a/invokeai/frontend/web/src/services/api/services/BoardsService.ts b/invokeai/frontend/web/src/services/api/services/BoardsService.ts new file mode 100644 index 0000000000..236c765cb9 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/services/BoardsService.ts @@ -0,0 +1,247 @@ +/* istanbul ignore file */ +/* tslint:disable */ +/* eslint-disable */ +import type { BoardChanges } from '../models/BoardChanges'; +import type { BoardDTO } from '../models/BoardDTO'; +import type { Body_create_board_image } from '../models/Body_create_board_image'; +import type { Body_remove_board_image } from '../models/Body_remove_board_image'; +import type { OffsetPaginatedResults_BoardDTO_ } from '../models/OffsetPaginatedResults_BoardDTO_'; +import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_'; + +import type { CancelablePromise } from '../core/CancelablePromise'; +import { OpenAPI } from '../core/OpenAPI'; +import { request as __request } from '../core/request'; + +export class BoardsService { + + /** + * List Boards + * Gets a list of boards + * @returns any Successful Response + * @throws ApiError + */ + public static listBoards({ + all, + offset, + limit, + }: { + /** + * Whether to list all boards + */ + all?: boolean, + /** + * The page offset + */ + offset?: number, + /** + * The number of boards per page + */ + limit?: number, + }): CancelablePromise<(OffsetPaginatedResults_BoardDTO_ | Array)> { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/boards/', + query: { + 'all': all, + 'offset': offset, + 'limit': limit, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Create Board + * Creates a board + * @returns BoardDTO The board was created successfully + * @throws ApiError + */ + public static createBoard({ + boardName, + }: { + /** + * The name of the board to create + */ + boardName: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/boards/', + query: { + 'board_name': boardName, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Get Board + * Gets a board + * @returns BoardDTO Successful Response + * @throws ApiError + */ + public static getBoard({ + boardId, + }: { + /** + * The id of board to get + */ + boardId: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Delete Board + * Deletes a board + * @returns any Successful Response + * @throws ApiError + */ + public static deleteBoard({ + boardId, + }: { + /** + * The id of board to delete + */ + boardId: string, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'DELETE', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Update Board + * Updates a board + * @returns BoardDTO The board was updated successfully + * @throws ApiError + */ + public static updateBoard({ + boardId, + requestBody, + }: { + /** + * The id of board to update + */ + boardId: string, + requestBody: BoardChanges, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'PATCH', + url: '/api/v1/boards/{board_id}', + path: { + 'board_id': boardId, + }, + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Create Board Image + * Creates a board_image + * @returns any The image was added to a board successfully + * @throws ApiError + */ + public static createBoardImage({ + requestBody, + }: { + requestBody: Body_create_board_image, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v1/board_images/', + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * Remove Board Image + * Deletes a board_image + * @returns any The image was removed from the board successfully + * @throws ApiError + */ + public static removeBoardImage({ + requestBody, + }: { + requestBody: Body_remove_board_image, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'DELETE', + url: '/api/v1/board_images/', + body: requestBody, + mediaType: 'application/json', + errors: { + 422: `Validation Error`, + }, + }); + } + + /** + * List Board Images + * Gets a list of images for a board + * @returns OffsetPaginatedResults_ImageDTO_ Successful Response + * @throws ApiError + */ + public static listBoardImages({ + boardId, + offset, + limit = 10, + }: { + /** + * The id of the board + */ + boardId: string, + /** + * The page offset + */ + offset?: number, + /** + * The number of boards per page + */ + limit?: number, + }): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v1/board_images/{board_id}', + path: { + 'board_id': boardId, + }, + query: { + 'offset': offset, + 'limit': limit, + }, + errors: { + 422: `Validation Error`, + }, + }); + } + +} diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts index d0eae92d4b..bfdef887a0 100644 --- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts +++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts @@ -25,6 +25,7 @@ export class ImagesService { imageOrigin, categories, isIntermediate, + boardId, offset, limit = 10, }: { @@ -40,6 +41,10 @@ export class ImagesService { * Whether to list intermediate images */ isIntermediate?: boolean, + /** + * The board id to filter by + */ + boardId?: string, /** * The page offset */ @@ -56,6 +61,7 @@ export class ImagesService { 'image_origin': imageOrigin, 'categories': categories, 'is_intermediate': isIntermediate, + 'board_id': boardId, 'offset': offset, 'limit': limit, }, diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts index 2e4a83b25f..51a36caad1 100644 --- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts +++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts @@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR import type { ParamFloatInvocation } from '../models/ParamFloatInvocation'; import type { ParamIntInvocation } from '../models/ParamIntInvocation'; import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation'; +import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation'; import type { RandomIntInvocation } from '../models/RandomIntInvocation'; import type { RandomRangeInvocation } from '../models/RandomRangeInvocation'; import type { RangeInvocation } from '../models/RangeInvocation'; @@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation'; import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation'; import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation'; import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation'; -import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation'; -import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation'; import type { ShowImageInvocation } from '../models/ShowImageInvocation'; import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation'; import type { SubtractInvocation } from '../models/SubtractInvocation'; @@ -175,7 +174,7 @@ export class SessionsService { * The id of the session */ sessionId: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'POST', @@ -212,7 +211,7 @@ export class SessionsService { * The path to the node in the graph */ nodePath: string, - requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), + requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation), }): CancelablePromise { return __request(OpenAPI, { method: 'PUT', diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts new file mode 100644 index 0000000000..e2d765dd90 --- /dev/null +++ b/invokeai/frontend/web/src/services/apiSlice.ts @@ -0,0 +1,223 @@ +import { + TagDescription, + createApi, + fetchBaseQuery, +} from '@reduxjs/toolkit/query/react'; +import { BoardDTO } from './api/models/BoardDTO'; +import { OffsetPaginatedResults_BoardDTO_ } from './api/models/OffsetPaginatedResults_BoardDTO_'; +import { BoardChanges } from './api/models/BoardChanges'; +import { OffsetPaginatedResults_ImageDTO_ } from './api/models/OffsetPaginatedResults_ImageDTO_'; +import { ImageDTO } from './api/models/ImageDTO'; +import { + FullTagDescription, + TagTypesFrom, + TagTypesFromApi, +} from '@reduxjs/toolkit/dist/query/endpointDefinitions'; +import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; +import { BaseModelType } from './api/models/BaseModelType'; +import { ModelType } from './api/models/ModelType'; +import { ModelsList } from './api/models/ModelsList'; +import { keyBy } from 'lodash-es'; + +type ListBoardsArg = { offset: number; limit: number }; +type UpdateBoardArg = { board_id: string; changes: BoardChanges }; +type AddImageToBoardArg = { board_id: string; image_name: string }; +type RemoveImageFromBoardArg = { board_id: string; image_name: string }; +type ListBoardImagesArg = { board_id: string; offset: number; limit: number }; +type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType }; + +type ModelConfig = ModelsList['models'][number]; + +const tagTypes = ['Board', 'Image', 'Model']; +type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>; + +const LIST = 'LIST'; + +const modelsAdapter = createEntityAdapter({ + selectId: (model) => getModelId(model), + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); + +const getModelId = ({ base_model, type, name }: ModelConfig) => + `${base_model}/${type}/${name}`; + +export const api = createApi({ + baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }), + reducerPath: 'api', + tagTypes, + endpoints: (build) => ({ + /** + * Models Queries + */ + + listModels: build.query, ListModelsArg>({ + query: (arg) => ({ url: 'models/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.ids.map((id) => ({ + type: 'Model' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: (response: ModelsList, meta, arg) => { + return modelsAdapter.addMany( + modelsAdapter.getInitialState(), + keyBy(response.models, getModelId) + ); + }, + }), + /** + * Boards Queries + */ + listBoards: build.query({ + query: (arg) => ({ url: 'boards/', params: arg }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.items.map(({ board_id }) => ({ + type: 'Board' as const, + id: board_id, + })) + ); + } + + return tags; + }, + }), + + listAllBoards: build.query, void>({ + query: () => ({ + url: 'boards/', + params: { all: true }, + }), + providesTags: (result, error, arg) => { + // any list of boards + const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }]; + + if (result) { + // and individual tags for each board + tags.push( + ...result.map(({ board_id }) => ({ + type: 'Board' as const, + id: board_id, + })) + ); + } + + return tags; + }, + }), + + /** + * Boards Mutations + */ + + createBoard: build.mutation({ + query: (board_name) => ({ + url: `boards/`, + method: 'POST', + params: { board_name }, + }), + invalidatesTags: [{ id: 'Board', type: LIST }], + }), + + updateBoard: build.mutation({ + query: ({ board_id, changes }) => ({ + url: `boards/${board_id}`, + method: 'PATCH', + body: changes, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + ], + }), + + deleteBoard: build.mutation({ + query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }), + invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }], + }), + + /** + * Board Images Queries + */ + + listBoardImages: build.query< + OffsetPaginatedResults_ImageDTO_, + ListBoardImagesArg + >({ + query: ({ board_id, offset, limit }) => ({ + url: `board_images/${board_id}`, + method: 'DELETE', + body: { offset, limit }, + }), + }), + + /** + * Board Images Mutations + */ + + addImageToBoard: build.mutation({ + query: ({ board_id, image_name }) => ({ + url: `board_images/`, + method: 'POST', + body: { board_id, image_name }, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + { type: 'Image', id: arg.image_name }, + ], + }), + + removeImageFromBoard: build.mutation({ + query: ({ board_id, image_name }) => ({ + url: `board_images/`, + method: 'DELETE', + body: { board_id, image_name }, + }), + invalidatesTags: (result, error, arg) => [ + { type: 'Board', id: arg.board_id }, + { type: 'Image', id: arg.image_name }, + ], + }), + + /** + * Image Queries + */ + getImageDTO: build.query({ + query: (image_name) => ({ url: `images/${image_name}/metadata` }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }]; + if (result?.board_id) { + tags.push({ type: 'Board', id: result.board_id }); + } + return tags; + }, + }), + }), +}); + +export const { + useListBoardsQuery, + useListAllBoardsQuery, + useCreateBoardMutation, + useUpdateBoardMutation, + useDeleteBoardMutation, + useAddImageToBoardMutation, + useRemoveImageFromBoardMutation, + useListBoardImagesQuery, + useGetImageDTOQuery, + useListModelsQuery, +} = api; diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 5832cb24b1..ed154b9cd8 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -53,14 +53,14 @@ export const appSocketDisconnected = createAction( * Do not use. Only for use in middleware. */ export const socketSubscribed = createAction< - BaseSocketPayload & { sessionId: string } + BaseSocketPayload & { sessionId: string; boardId: string | undefined } >('socket/socketSubscribed'); /** * App-level Socket.IO Subscribed */ export const appSocketSubscribed = createAction< - BaseSocketPayload & { sessionId: string } + BaseSocketPayload & { sessionId: string; boardId: string | undefined } >('socket/appSocketSubscribed'); /** diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index f1eb844f2c..5b427b1690 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -85,6 +85,7 @@ export const socketMiddleware = () => { socketSubscribed({ sessionId: sessionId, timestamp: getTimestamp(), + boardId: getState().boards.selectedBoardId, }) ); } diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts index 2c4cba510a..62b5864185 100644 --- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts +++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts @@ -44,6 +44,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => { socketSubscribed({ sessionId, timestamp: getTimestamp(), + boardId: getState().boards.selectedBoardId, }) ); } diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts index a0725bf235..fe198cf6f9 100644 --- a/invokeai/frontend/web/src/services/thunks/image.ts +++ b/invokeai/frontend/web/src/services/thunks/image.ts @@ -1,5 +1,6 @@ import { createAppAsyncThunk } from 'app/store/storeUtils'; import { selectImagesAll } from 'features/gallery/store/imagesSlice'; +import { size } from 'lodash-es'; import { ImagesService } from 'services/api'; type imageUrlsReceivedArg = Parameters< @@ -121,25 +122,61 @@ type ImagesListedArg = Parameters< export const IMAGES_PER_PAGE = 20; +const DEFAULT_IMAGES_LISTED_ARG = { + isIntermediate: false, + limit: IMAGES_PER_PAGE, +}; + /** * `ImagesService.listImagesWithMetadata()` thunk */ export const receivedPageOfImages = createAppAsyncThunk( 'api/receivedPageOfImages', - async (_, { getState }) => { + async (arg: ImagesListedArg, { getState }) => { const state = getState(); const { categories } = state.images; + const { selectedBoardId } = state.boards; - const totalImagesInFilter = selectImagesAll(state).filter((i) => - categories.includes(i.image_category) - ).length; - - const response = await ImagesService.listImagesWithMetadata({ - categories, - isIntermediate: false, - offset: totalImagesInFilter, - limit: IMAGES_PER_PAGE, + const images = selectImagesAll(state).filter((i) => { + const isInCategory = categories.includes(i.image_category); + const isInSelectedBoard = selectedBoardId + ? i.board_id === selectedBoardId + : true; + return isInCategory && isInSelectedBoard; }); + + let queryArg: ReceivedImagesArg = {}; + + if (size(arg)) { + queryArg = { + ...DEFAULT_IMAGES_LISTED_ARG, + offset: images.length, + ...arg, + }; + } else { + queryArg = { + ...DEFAULT_IMAGES_LISTED_ARG, + categories, + offset: images.length, + }; + } + + const response = await ImagesService.listImagesWithMetadata(queryArg); + return response; + } +); + +type ReceivedImagesArg = Parameters< + (typeof ImagesService)['listImagesWithMetadata'] +>[0]; + +/** + * `ImagesService.listImagesWithMetadata()` thunk + */ +export const receivedImages = createAppAsyncThunk( + 'api/receivedImages', + async (arg: ReceivedImagesArg, { getState }) => { + const response = await ImagesService.listImagesWithMetadata(arg); return response; } ); diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts deleted file mode 100644 index 97f2bd8016..0000000000 --- a/invokeai/frontend/web/src/services/thunks/model.ts +++ /dev/null @@ -1,33 +0,0 @@ -import { log } from 'app/logging/useLogger'; -import { createAppAsyncThunk } from 'app/store/storeUtils'; -import { Model } from 'features/system/store/modelSlice'; -import { reduce, size } from 'lodash-es'; -import { ModelsService } from 'services/api'; - -const models = log.child({ namespace: 'model' }); - -export const IMAGES_PER_PAGE = 20; - -export const receivedModels = createAppAsyncThunk( - 'models/receivedModels', - async (_) => { - const response = await ModelsService.listModels(); - - const deserializedModels = reduce( - response.models['sd-1']['pipeline'], - (modelsAccumulator, model, modelName) => { - modelsAccumulator[modelName] = { ...model, name: modelName }; - - return modelsAccumulator; - }, - {} as Record - ); - - models.info( - { response }, - `Received ${size(response.models['sd-1']['pipeline'])} models` - ); - - return deserializedModels; - } -); diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts index 334c04e6ed..7ac0d95e6a 100644 --- a/invokeai/frontend/web/src/services/types/guards.ts +++ b/invokeai/frontend/web/src/services/types/guards.ts @@ -11,6 +11,7 @@ import { LatentsOutput, ResourceOrigin, ImageDTO, + BoardDTO, } from 'services/api'; export const isImageDTO = (obj: unknown): obj is ImageDTO => { @@ -29,6 +30,16 @@ export const isImageDTO = (obj: unknown): obj is ImageDTO => { ); }; +export const isBoardDTO = (obj: unknown): obj is BoardDTO => { + return ( + isObject(obj) && + 'board_id' in obj && + isString(obj?.board_id) && + 'board_name' in obj && + isString(obj?.board_name) + ); +}; + export const isImageOutput = ( output: GraphExecutionState['results'][string] ): output is ImageOutput => output.type === 'image_output'; diff --git a/invokeai/frontend/web/src/theme/colors/greenTea.ts b/invokeai/frontend/web/src/theme/colors/greenTea.ts index ffecbf2ffa..318aecbc61 100644 --- a/invokeai/frontend/web/src/theme/colors/greenTea.ts +++ b/invokeai/frontend/web/src/theme/colors/greenTea.ts @@ -4,8 +4,8 @@ import { generateColorPalette } from '../util/generateColorPalette'; export const greenTeaThemeColors: InvokeAIThemeColors = { base: generateColorPalette(223, 10), baseAlpha: generateColorPalette(223, 10, false, true), - accent: generateColorPalette(155, 80), - accentAlpha: generateColorPalette(155, 80, false, true), + accent: generateColorPalette(160, 60), + accentAlpha: generateColorPalette(160, 60, false, true), working: generateColorPalette(47, 68), workingAlpha: generateColorPalette(47, 68, false, true), warning: generateColorPalette(28, 75), @@ -14,5 +14,5 @@ export const greenTeaThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, false, true), error: generateColorPalette(0, 50), errorAlpha: generateColorPalette(0, 50, false, true), - gridLineColor: 'rgba(255, 255, 255, 0.2)', + gridLineColor: 'rgba(255, 255, 255, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/invokeAI.ts b/invokeai/frontend/web/src/theme/colors/invokeAI.ts index c39b3bed81..82db58bd35 100644 --- a/invokeai/frontend/web/src/theme/colors/invokeAI.ts +++ b/invokeai/frontend/web/src/theme/colors/invokeAI.ts @@ -2,8 +2,8 @@ import { InvokeAIThemeColors } from 'theme/themeTypes'; import { generateColorPalette } from 'theme/util/generateColorPalette'; export const invokeAIThemeColors: InvokeAIThemeColors = { - base: generateColorPalette(225, 15), - baseAlpha: generateColorPalette(225, 15, false, true), + base: generateColorPalette(220, 15), + baseAlpha: generateColorPalette(220, 15, false, true), accent: generateColorPalette(250, 50), accentAlpha: generateColorPalette(250, 50, false, true), working: generateColorPalette(47, 67), @@ -14,5 +14,5 @@ export const invokeAIThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(113, 70, false, true), error: generateColorPalette(0, 76), errorAlpha: generateColorPalette(0, 76, false, true), - gridLineColor: 'rgba(255, 255, 255, 0.2)', + gridLineColor: 'rgba(150, 150, 180, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/lightTheme.ts b/invokeai/frontend/web/src/theme/colors/lightTheme.ts index 2a7a05bbd2..2fdbd1a769 100644 --- a/invokeai/frontend/web/src/theme/colors/lightTheme.ts +++ b/invokeai/frontend/web/src/theme/colors/lightTheme.ts @@ -14,5 +14,5 @@ export const lightThemeColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, true, true), error: generateColorPalette(0, 50, true), errorAlpha: generateColorPalette(0, 50, true, true), - gridLineColor: 'rgba(0, 0, 0, 0.2)', + gridLineColor: 'rgba(0, 0, 0, 0.15)', }; diff --git a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts index adfb8ab288..952e0a5066 100644 --- a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts +++ b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts @@ -14,5 +14,5 @@ export const oceanBlueColors: InvokeAIThemeColors = { okAlpha: generateColorPalette(122, 49, false, true), error: generateColorPalette(0, 100), errorAlpha: generateColorPalette(0, 100, false, true), - gridLineColor: 'rgba(136, 148, 184, 0.2)', + gridLineColor: 'rgba(136, 148, 184, 0.15)', };