diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 5599d569d5..efeb778922 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -2,8 +2,17 @@
from logging import Logger
import os
+from invokeai.app.services.board_image_record_storage import (
+ SqliteBoardImageRecordStorage,
+)
+from invokeai.app.services.board_images import (
+ BoardImagesService,
+ BoardImagesServiceDependencies,
+)
+from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
+from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
-from invokeai.app.services.images import ImageService
+from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
@@ -57,7 +66,7 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_location = config.db_path
- db_location.parent.mkdir(parents=True,exist_ok=True)
+ db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@@ -72,14 +81,40 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents")
)
+ board_record_storage = SqliteBoardRecordStorage(db_location)
+ board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
+
+ boards = BoardService(
+ services=BoardServiceDependencies(
+ board_image_record_storage=board_image_record_storage,
+ board_record_storage=board_record_storage,
+ image_record_storage=image_record_storage,
+ url=urls,
+ logger=logger,
+ )
+ )
+
+ board_images = BoardImagesService(
+ services=BoardImagesServiceDependencies(
+ board_image_record_storage=board_image_record_storage,
+ board_record_storage=board_record_storage,
+ image_record_storage=image_record_storage,
+ url=urls,
+ logger=logger,
+ )
+ )
+
images = ImageService(
- image_record_storage=image_record_storage,
- image_file_storage=image_file_storage,
- metadata=metadata,
- url=urls,
- logger=logger,
- names=names,
- graph_execution_manager=graph_execution_manager,
+ services=ImageServiceDependencies(
+ board_image_record_storage=board_image_record_storage,
+ image_record_storage=image_record_storage,
+ image_file_storage=image_file_storage,
+ metadata=metadata,
+ url=urls,
+ logger=logger,
+ names=names,
+ graph_execution_manager=graph_execution_manager,
+ )
)
services = InvocationServices(
@@ -87,6 +122,8 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
+ boards=boards,
+ board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py
new file mode 100644
index 0000000000..b206ab500d
--- /dev/null
+++ b/invokeai/app/api/routers/board_images.py
@@ -0,0 +1,69 @@
+from fastapi import Body, HTTPException, Path, Query
+from fastapi.routing import APIRouter
+from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.board_record import BoardDTO
+from invokeai.app.services.models.image_record import ImageDTO
+
+from ..dependencies import ApiDependencies
+
+board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
+
+
+@board_images_router.post(
+ "/",
+ operation_id="create_board_image",
+ responses={
+ 201: {"description": "The image was added to a board successfully"},
+ },
+ status_code=201,
+)
+async def create_board_image(
+ board_id: str = Body(description="The id of the board to add to"),
+ image_name: str = Body(description="The name of the image to add"),
+):
+ """Creates a board_image"""
+ try:
+ result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
+ return result
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Failed to add to board")
+
+@board_images_router.delete(
+ "/",
+ operation_id="remove_board_image",
+ responses={
+ 201: {"description": "The image was removed from the board successfully"},
+ },
+ status_code=201,
+)
+async def remove_board_image(
+ board_id: str = Body(description="The id of the board"),
+ image_name: str = Body(description="The name of the image to remove"),
+):
+ """Deletes a board_image"""
+ try:
+ result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
+ return result
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Failed to update board")
+
+
+
+@board_images_router.get(
+ "/{board_id}",
+ operation_id="list_board_images",
+ response_model=OffsetPaginatedResults[ImageDTO],
+)
+async def list_board_images(
+ board_id: str = Path(description="The id of the board"),
+ offset: int = Query(default=0, description="The page offset"),
+ limit: int = Query(default=10, description="The number of boards per page"),
+) -> OffsetPaginatedResults[ImageDTO]:
+ """Gets a list of images for a board"""
+
+ results = ApiDependencies.invoker.services.board_images.get_images_for_board(
+ board_id,
+ )
+ return results
+
diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py
new file mode 100644
index 0000000000..55cd7c8ca2
--- /dev/null
+++ b/invokeai/app/api/routers/boards.py
@@ -0,0 +1,108 @@
+from typing import Optional, Union
+from fastapi import Body, HTTPException, Path, Query
+from fastapi.routing import APIRouter
+from invokeai.app.services.board_record_storage import BoardChanges
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.board_record import BoardDTO
+
+from ..dependencies import ApiDependencies
+
+boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
+
+
+@boards_router.post(
+ "/",
+ operation_id="create_board",
+ responses={
+ 201: {"description": "The board was created successfully"},
+ },
+ status_code=201,
+ response_model=BoardDTO,
+)
+async def create_board(
+ board_name: str = Query(description="The name of the board to create"),
+) -> BoardDTO:
+ """Creates a board"""
+ try:
+ result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
+ return result
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Failed to create board")
+
+
+@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
+async def get_board(
+ board_id: str = Path(description="The id of board to get"),
+) -> BoardDTO:
+ """Gets a board"""
+
+ try:
+ result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
+ return result
+ except Exception as e:
+ raise HTTPException(status_code=404, detail="Board not found")
+
+
+@boards_router.patch(
+ "/{board_id}",
+ operation_id="update_board",
+ responses={
+ 201: {
+ "description": "The board was updated successfully",
+ },
+ },
+ status_code=201,
+ response_model=BoardDTO,
+)
+async def update_board(
+ board_id: str = Path(description="The id of board to update"),
+ changes: BoardChanges = Body(description="The changes to apply to the board"),
+) -> BoardDTO:
+ """Updates a board"""
+ try:
+ result = ApiDependencies.invoker.services.boards.update(
+ board_id=board_id, changes=changes
+ )
+ return result
+ except Exception as e:
+ raise HTTPException(status_code=500, detail="Failed to update board")
+
+
+@boards_router.delete("/{board_id}", operation_id="delete_board")
+async def delete_board(
+ board_id: str = Path(description="The id of board to delete"),
+) -> None:
+ """Deletes a board"""
+
+ try:
+ ApiDependencies.invoker.services.boards.delete(board_id=board_id)
+ except Exception as e:
+ # TODO: Does this need any exception handling at all?
+ pass
+
+
+@boards_router.get(
+ "/",
+ operation_id="list_boards",
+ response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
+)
+async def list_boards(
+ all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
+ offset: Optional[int] = Query(default=None, description="The page offset"),
+ limit: Optional[int] = Query(
+ default=None, description="The number of boards per page"
+ ),
+) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
+ """Gets a list of boards"""
+ if all:
+ return ApiDependencies.invoker.services.boards.get_all()
+ elif offset is not None and limit is not None:
+ return ApiDependencies.invoker.services.boards.get_many(
+ offset,
+ limit,
+ )
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
+ )
diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py
index 11453d97f1..a8c84b81b9 100644
--- a/invokeai/app/api/routers/images.py
+++ b/invokeai/app/api/routers/images.py
@@ -221,6 +221,9 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
+ board_id: Optional[str] = Query(
+ default=None, description="The board id to filter by"
+ ),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
@@ -232,6 +235,7 @@ async def list_images_with_metadata(
image_origin,
categories,
is_intermediate,
+ board_id,
)
return image_dtos
diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py
index f510279f18..50d645eb57 100644
--- a/invokeai/app/api/routers/models.py
+++ b/invokeai/app/api/routers/models.py
@@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management.models import get_all_model_configs
-MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
+from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
+MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
- models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
- #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
+ models: list[MODEL_CONFIGS]
@models_router.get(
@@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
- base_model: BaseModelType = Query(
+ base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
- model_type: ModelType = Query(
+ model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index fa46762d56..e14c58bab7 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies
-from .api.routers import sessions, models, images
+from .api.routers import sessions, models, images, boards, board_images
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
@@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
+app.include_router(boards.boards_router, prefix="/api")
+
+app.include_router(board_images.board_images_router, prefix="/api")
+
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
@@ -116,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
+ from invokeai.backend.model_management.models import get_model_config_enums
+ for model_config_format_enum in set(get_model_config_enums()):
+ name = model_config_format_enum.__qualname__
+
+ if name in openapi_schema["components"]["schemas"]:
+ # print(f"Config with name {name} already defined")
+ continue
+
+ # "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
+ openapi_schema["components"]["schemas"][name] = dict(
+ title=name,
+ description="An enumeration.",
+ type="string",
+ enum=list(v.value for v in model_config_format_enum),
+ )
+
app.openapi_schema = openapi_schema
return app.openapi_schema
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 9d77cadf8c..b77aa5dafd 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on
-class SD1ModelLoaderInvocation(BaseInvocation):
- """Loading submodels of selected model."""
+class PipelineModelField(BaseModel):
+ """Pipeline model field"""
- type: Literal["sd1_model_loader"] = "sd1_model_loader"
+ model_name: str = Field(description="Name of the model")
+ base_model: BaseModelType = Field(description="Base model")
- model_name: str = Field(default="", description="Model to load")
+
+class PipelineModelLoaderInvocation(BaseInvocation):
+ """Loads a pipeline model, outputting its submodels."""
+
+ type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
+
+ model: PipelineModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
@@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
"ui": {
"tags": ["model", "loader"],
"type_hints": {
- "model_name": "model" # TODO: rename to model_name?
+ "model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
- base_model = BaseModelType.StableDiffusion1 # TODO:
+ base_model = self.model.base_model
+ model_name = self.model.model_name
+ model_type = ModelType.Pipeline
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
):
- raise Exception(f"Unkown model name: {self.model_name}!")
+ raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
@@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
- model_name=self.model_name,
+ model_name=model_name,
base_model=base_model,
- model_type=ModelType.Pipeline,
- submodel=SubModelType.Vae,
- ),
- )
- )
-
-# TODO: optimize(less code copy)
-class SD2ModelLoaderInvocation(BaseInvocation):
- """Loading submodels of selected model."""
-
- type: Literal["sd2_model_loader"] = "sd2_model_loader"
-
- model_name: str = Field(default="", description="Model to load")
- # TODO: precision?
-
- # Schema customisation
- class Config(InvocationConfig):
- schema_extra = {
- "ui": {
- "tags": ["model", "loader"],
- "type_hints": {
- "model_name": "model" # TODO: rename to model_name?
- }
- },
- }
-
- def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
-
- base_model = BaseModelType.StableDiffusion2 # TODO:
-
- # TODO: not found exceptions
- if not context.services.model_manager.model_exists(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
- ):
- raise Exception(f"Unkown model name: {self.model_name}!")
-
- """
- if not context.services.model_manager.model_exists(
- model_name=self.model_name,
- model_type=SDModelType.Diffusers,
- submodel=SDModelType.Tokenizer,
- ):
- raise Exception(
- f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
- )
-
- if not context.services.model_manager.model_exists(
- model_name=self.model_name,
- model_type=SDModelType.Diffusers,
- submodel=SDModelType.TextEncoder,
- ):
- raise Exception(
- f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
- )
-
- if not context.services.model_manager.model_exists(
- model_name=self.model_name,
- model_type=SDModelType.Diffusers,
- submodel=SDModelType.UNet,
- ):
- raise Exception(
- f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
- )
- """
-
-
- return ModelLoaderOutput(
- unet=UNetField(
- unet=ModelInfo(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
- submodel=SubModelType.UNet,
- ),
- scheduler=ModelInfo(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
- submodel=SubModelType.Scheduler,
- ),
- loras=[],
- ),
- clip=ClipField(
- tokenizer=ModelInfo(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
- submodel=SubModelType.Tokenizer,
- ),
- text_encoder=ModelInfo(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
- submodel=SubModelType.TextEncoder,
- ),
- loras=[],
- ),
- vae=VaeField(
- vae=ModelInfo(
- model_name=self.model_name,
- base_model=base_model,
- model_type=ModelType.Pipeline,
+ model_type=model_type,
submodel=SubModelType.Vae,
),
)
diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py
new file mode 100644
index 0000000000..7aff41860c
--- /dev/null
+++ b/invokeai/app/services/board_image_record_storage.py
@@ -0,0 +1,254 @@
+from abc import ABC, abstractmethod
+import sqlite3
+import threading
+from typing import Union, cast
+from invokeai.app.services.board_record_storage import BoardRecord
+
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.image_record import (
+ ImageRecord,
+ deserialize_image_record,
+)
+
+
+class BoardImageRecordStorageBase(ABC):
+ """Abstract base class for the one-to-many board-image relationship record storage."""
+
+ @abstractmethod
+ def add_image_to_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ """Adds an image to a board."""
+ pass
+
+ @abstractmethod
+ def remove_image_from_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ """Removes an image from a board."""
+ pass
+
+ @abstractmethod
+ def get_images_for_board(
+ self,
+ board_id: str,
+ ) -> OffsetPaginatedResults[ImageRecord]:
+ """Gets images for a board."""
+ pass
+
+ @abstractmethod
+ def get_board_for_image(
+ self,
+ image_name: str,
+ ) -> Union[str, None]:
+ """Gets an image's board id, if it has one."""
+ pass
+
+ @abstractmethod
+ def get_image_count_for_board(
+ self,
+ board_id: str,
+ ) -> int:
+ """Gets the number of images for a board."""
+ pass
+
+
+class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
+ _filename: str
+ _conn: sqlite3.Connection
+ _cursor: sqlite3.Cursor
+ _lock: threading.Lock
+
+ def __init__(self, filename: str) -> None:
+ super().__init__()
+ self._filename = filename
+ self._conn = sqlite3.connect(filename, check_same_thread=False)
+ # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
+ self._conn.row_factory = sqlite3.Row
+ self._cursor = self._conn.cursor()
+ self._lock = threading.Lock()
+
+ try:
+ self._lock.acquire()
+ # Enable foreign keys
+ self._conn.execute("PRAGMA foreign_keys = ON;")
+ self._create_tables()
+ self._conn.commit()
+ finally:
+ self._lock.release()
+
+ def _create_tables(self) -> None:
+ """Creates the `board_images` junction table."""
+
+ # Create the `board_images` junction table.
+ self._cursor.execute(
+ """--sql
+ CREATE TABLE IF NOT EXISTS board_images (
+ board_id TEXT NOT NULL,
+ image_name TEXT NOT NULL,
+ created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
+ -- updated via trigger
+ updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
+ -- Soft delete, currently unused
+ deleted_at DATETIME,
+ -- enforce one-to-many relationship between boards and images using PK
+ -- (we can extend this to many-to-many later)
+ PRIMARY KEY (image_name),
+ FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
+ FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
+ );
+ """
+ )
+
+ # Add index for board id
+ self._cursor.execute(
+ """--sql
+ CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
+ """
+ )
+
+ # Add index for board id, sorted by created_at
+ self._cursor.execute(
+ """--sql
+ CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
+ """
+ )
+
+ # Add trigger for `updated_at`.
+ self._cursor.execute(
+ """--sql
+ CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
+ AFTER UPDATE
+ ON board_images FOR EACH ROW
+ BEGIN
+ UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
+ WHERE board_id = old.board_id AND image_name = old.image_name;
+ END;
+ """
+ )
+
+ def add_image_to_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ INSERT INTO board_images (board_id, image_name)
+ VALUES (?, ?)
+ ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
+ """,
+ (board_id, image_name, board_id),
+ )
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
+
+ def remove_image_from_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ DELETE FROM board_images
+ WHERE board_id = ? AND image_name = ?;
+ """,
+ (board_id, image_name),
+ )
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
+
+ def get_images_for_board(
+ self,
+ board_id: str,
+ offset: int = 0,
+ limit: int = 10,
+ ) -> OffsetPaginatedResults[ImageRecord]:
+ # TODO: this isn't paginated yet?
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ SELECT images.*
+ FROM board_images
+ INNER JOIN images ON board_images.image_name = images.image_name
+ WHERE board_images.board_id = ?
+ ORDER BY board_images.updated_at DESC;
+ """,
+ (board_id,),
+ )
+ result = cast(list[sqlite3.Row], self._cursor.fetchall())
+ images = list(map(lambda r: deserialize_image_record(dict(r)), result))
+
+ self._cursor.execute(
+ """--sql
+ SELECT COUNT(*) FROM images WHERE 1=1;
+ """
+ )
+ count = cast(int, self._cursor.fetchone()[0])
+
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
+ return OffsetPaginatedResults(
+ items=images, offset=offset, limit=limit, total=count
+ )
+
+ def get_board_for_image(
+ self,
+ image_name: str,
+ ) -> Union[str, None]:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ SELECT board_id
+ FROM board_images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
+ result = self._cursor.fetchone()
+ if result is None:
+ return None
+ return cast(str, result[0])
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
+
+ def get_image_count_for_board(self, board_id: str) -> int:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ SELECT COUNT(*) FROM board_images WHERE board_id = ?;
+ """,
+ (board_id,),
+ )
+ count = cast(int, self._cursor.fetchone()[0])
+ return count
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py
new file mode 100644
index 0000000000..072effbfae
--- /dev/null
+++ b/invokeai/app/services/board_images.py
@@ -0,0 +1,142 @@
+from abc import ABC, abstractmethod
+from logging import Logger
+from typing import List, Union
+from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
+from invokeai.app.services.board_record_storage import (
+ BoardRecord,
+ BoardRecordStorageBase,
+)
+
+from invokeai.app.services.image_record_storage import (
+ ImageRecordStorageBase,
+ OffsetPaginatedResults,
+)
+from invokeai.app.services.models.board_record import BoardDTO
+from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
+from invokeai.app.services.urls import UrlServiceBase
+
+
+class BoardImagesServiceABC(ABC):
+ """High-level service for board-image relationship management."""
+
+ @abstractmethod
+ def add_image_to_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ """Adds an image to a board."""
+ pass
+
+ @abstractmethod
+ def remove_image_from_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ """Removes an image from a board."""
+ pass
+
+ @abstractmethod
+ def get_images_for_board(
+ self,
+ board_id: str,
+ ) -> OffsetPaginatedResults[ImageDTO]:
+ """Gets images for a board."""
+ pass
+
+ @abstractmethod
+ def get_board_for_image(
+ self,
+ image_name: str,
+ ) -> Union[str, None]:
+ """Gets an image's board id, if it has one."""
+ pass
+
+
+class BoardImagesServiceDependencies:
+ """Service dependencies for the BoardImagesService."""
+
+ board_image_records: BoardImageRecordStorageBase
+ board_records: BoardRecordStorageBase
+ image_records: ImageRecordStorageBase
+ urls: UrlServiceBase
+ logger: Logger
+
+ def __init__(
+ self,
+ board_image_record_storage: BoardImageRecordStorageBase,
+ image_record_storage: ImageRecordStorageBase,
+ board_record_storage: BoardRecordStorageBase,
+ url: UrlServiceBase,
+ logger: Logger,
+ ):
+ self.board_image_records = board_image_record_storage
+ self.image_records = image_record_storage
+ self.board_records = board_record_storage
+ self.urls = url
+ self.logger = logger
+
+
+class BoardImagesService(BoardImagesServiceABC):
+ _services: BoardImagesServiceDependencies
+
+ def __init__(self, services: BoardImagesServiceDependencies):
+ self._services = services
+
+ def add_image_to_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ self._services.board_image_records.add_image_to_board(board_id, image_name)
+
+ def remove_image_from_board(
+ self,
+ board_id: str,
+ image_name: str,
+ ) -> None:
+ self._services.board_image_records.remove_image_from_board(board_id, image_name)
+
+ def get_images_for_board(
+ self,
+ board_id: str,
+ ) -> OffsetPaginatedResults[ImageDTO]:
+ image_records = self._services.board_image_records.get_images_for_board(
+ board_id
+ )
+ image_dtos = list(
+ map(
+ lambda r: image_record_to_dto(
+ r,
+ self._services.urls.get_image_url(r.image_name),
+ self._services.urls.get_image_url(r.image_name, True),
+ board_id,
+ ),
+ image_records.items,
+ )
+ )
+ return OffsetPaginatedResults[ImageDTO](
+ items=image_dtos,
+ offset=image_records.offset,
+ limit=image_records.limit,
+ total=image_records.total,
+ )
+
+ def get_board_for_image(
+ self,
+ image_name: str,
+ ) -> Union[str, None]:
+ board_id = self._services.board_image_records.get_board_for_image(image_name)
+ return board_id
+
+
+def board_record_to_dto(
+ board_record: BoardRecord, cover_image_name: str | None, image_count: int
+) -> BoardDTO:
+ """Converts a board record to a board DTO."""
+ return BoardDTO(
+ **board_record.dict(exclude={'cover_image_name'}),
+ cover_image_name=cover_image_name,
+ image_count=image_count,
+ )
diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py
new file mode 100644
index 0000000000..15ea9cc5a7
--- /dev/null
+++ b/invokeai/app/services/board_record_storage.py
@@ -0,0 +1,329 @@
+from abc import ABC, abstractmethod
+from typing import Optional, cast
+import sqlite3
+import threading
+from typing import Optional, Union
+import uuid
+from invokeai.app.services.image_record_storage import OffsetPaginatedResults
+from invokeai.app.services.models.board_record import (
+ BoardRecord,
+ deserialize_board_record,
+)
+
+from pydantic import BaseModel, Field, Extra
+
+
+class BoardChanges(BaseModel, extra=Extra.forbid):
+ board_name: Optional[str] = Field(description="The board's new name.")
+ cover_image_name: Optional[str] = Field(
+ description="The name of the board's new cover image."
+ )
+
+
+class BoardRecordNotFoundException(Exception):
+ """Raised when an board record is not found."""
+
+ def __init__(self, message="Board record not found"):
+ super().__init__(message)
+
+
+class BoardRecordSaveException(Exception):
+ """Raised when an board record cannot be saved."""
+
+ def __init__(self, message="Board record not saved"):
+ super().__init__(message)
+
+
+class BoardRecordDeleteException(Exception):
+ """Raised when an board record cannot be deleted."""
+
+ def __init__(self, message="Board record not deleted"):
+ super().__init__(message)
+
+
+class BoardRecordStorageBase(ABC):
+ """Low-level service responsible for interfacing with the board record store."""
+
+ @abstractmethod
+ def delete(self, board_id: str) -> None:
+ """Deletes a board record."""
+ pass
+
+ @abstractmethod
+ def save(
+ self,
+ board_name: str,
+ ) -> BoardRecord:
+ """Saves a board record."""
+ pass
+
+ @abstractmethod
+ def get(
+ self,
+ board_id: str,
+ ) -> BoardRecord:
+ """Gets a board record."""
+ pass
+
+ @abstractmethod
+ def update(
+ self,
+ board_id: str,
+ changes: BoardChanges,
+ ) -> BoardRecord:
+ """Updates a board record."""
+ pass
+
+ @abstractmethod
+ def get_many(
+ self,
+ offset: int = 0,
+ limit: int = 10,
+ ) -> OffsetPaginatedResults[BoardRecord]:
+ """Gets many board records."""
+ pass
+
+ @abstractmethod
+ def get_all(
+ self,
+ ) -> list[BoardRecord]:
+ """Gets all board records."""
+ pass
+
+
+class SqliteBoardRecordStorage(BoardRecordStorageBase):
+ _filename: str
+ _conn: sqlite3.Connection
+ _cursor: sqlite3.Cursor
+ _lock: threading.Lock
+
+ def __init__(self, filename: str) -> None:
+ super().__init__()
+ self._filename = filename
+ self._conn = sqlite3.connect(filename, check_same_thread=False)
+ # Enable row factory to get rows as dictionaries (must be done before making the cursor!)
+ self._conn.row_factory = sqlite3.Row
+ self._cursor = self._conn.cursor()
+ self._lock = threading.Lock()
+
+ try:
+ self._lock.acquire()
+ # Enable foreign keys
+ self._conn.execute("PRAGMA foreign_keys = ON;")
+ self._create_tables()
+ self._conn.commit()
+ finally:
+ self._lock.release()
+
+ def _create_tables(self) -> None:
+ """Creates the `boards` table and `board_images` junction table."""
+
+ # Create the `boards` table.
+ self._cursor.execute(
+ """--sql
+ CREATE TABLE IF NOT EXISTS boards (
+ board_id TEXT NOT NULL PRIMARY KEY,
+ board_name TEXT NOT NULL,
+ cover_image_name TEXT,
+ created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
+ -- Updated via trigger
+ updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
+ -- Soft delete, currently unused
+ deleted_at DATETIME,
+ FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
+ );
+ """
+ )
+
+ self._cursor.execute(
+ """--sql
+ CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
+ """
+ )
+
+ # Add trigger for `updated_at`.
+ self._cursor.execute(
+ """--sql
+ CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
+ AFTER UPDATE
+ ON boards FOR EACH ROW
+ BEGIN
+ UPDATE boards SET updated_at = current_timestamp
+ WHERE board_id = old.board_id;
+ END;
+ """
+ )
+
+ def delete(self, board_id: str) -> None:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ DELETE FROM boards
+ WHERE board_id = ?;
+ """,
+ (board_id,),
+ )
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise BoardRecordDeleteException from e
+ except Exception as e:
+ self._conn.rollback()
+ raise BoardRecordDeleteException from e
+ finally:
+ self._lock.release()
+
+ def save(
+ self,
+ board_name: str,
+ ) -> BoardRecord:
+ try:
+ board_id = str(uuid.uuid4())
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ INSERT OR IGNORE INTO boards (board_id, board_name)
+ VALUES (?, ?);
+ """,
+ (board_id, board_name),
+ )
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise BoardRecordSaveException from e
+ finally:
+ self._lock.release()
+ return self.get(board_id)
+
+ def get(
+ self,
+ board_id: str,
+ ) -> BoardRecord:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ SELECT *
+ FROM boards
+ WHERE board_id = ?;
+ """,
+ (board_id,),
+ )
+
+ result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise BoardRecordNotFoundException from e
+ finally:
+ self._lock.release()
+ if result is None:
+ raise BoardRecordNotFoundException
+ return BoardRecord(**dict(result))
+
+ def update(
+ self,
+ board_id: str,
+ changes: BoardChanges,
+ ) -> BoardRecord:
+ try:
+ self._lock.acquire()
+
+ # Change the name of a board
+ if changes.board_name is not None:
+ self._cursor.execute(
+ f"""--sql
+ UPDATE boards
+ SET board_name = ?
+ WHERE board_id = ?;
+ """,
+ (changes.board_name, board_id),
+ )
+
+ # Change the cover image of a board
+ if changes.cover_image_name is not None:
+ self._cursor.execute(
+ f"""--sql
+ UPDATE boards
+ SET cover_image_name = ?
+ WHERE board_id = ?;
+ """,
+ (changes.cover_image_name, board_id),
+ )
+
+ self._conn.commit()
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise BoardRecordSaveException from e
+ finally:
+ self._lock.release()
+ return self.get(board_id)
+
+ def get_many(
+ self,
+ offset: int = 0,
+ limit: int = 10,
+ ) -> OffsetPaginatedResults[BoardRecord]:
+ try:
+ self._lock.acquire()
+
+ # Get all the boards
+ self._cursor.execute(
+ """--sql
+ SELECT *
+ FROM boards
+ ORDER BY created_at DESC
+ LIMIT ? OFFSET ?;
+ """,
+ (limit, offset),
+ )
+
+ result = cast(list[sqlite3.Row], self._cursor.fetchall())
+ boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+
+ # Get the total number of boards
+ self._cursor.execute(
+ """--sql
+ SELECT COUNT(*)
+ FROM boards
+ WHERE 1=1;
+ """
+ )
+
+ count = cast(int, self._cursor.fetchone()[0])
+
+ return OffsetPaginatedResults[BoardRecord](
+ items=boards, offset=offset, limit=limit, total=count
+ )
+
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
+
+ def get_all(
+ self,
+ ) -> list[BoardRecord]:
+ try:
+ self._lock.acquire()
+
+ # Get all the boards
+ self._cursor.execute(
+ """--sql
+ SELECT *
+ FROM boards
+ ORDER BY created_at DESC
+ """
+ )
+
+ result = cast(list[sqlite3.Row], self._cursor.fetchall())
+ boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
+
+ return boards
+
+ except sqlite3.Error as e:
+ self._conn.rollback()
+ raise e
+ finally:
+ self._lock.release()
diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py
new file mode 100644
index 0000000000..9361322e6c
--- /dev/null
+++ b/invokeai/app/services/boards.py
@@ -0,0 +1,185 @@
+from abc import ABC, abstractmethod
+
+from logging import Logger
+from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
+from invokeai.app.services.board_images import board_record_to_dto
+
+from invokeai.app.services.board_record_storage import (
+ BoardChanges,
+ BoardRecordStorageBase,
+)
+from invokeai.app.services.image_record_storage import (
+ ImageRecordStorageBase,
+ OffsetPaginatedResults,
+)
+from invokeai.app.services.models.board_record import BoardDTO
+from invokeai.app.services.urls import UrlServiceBase
+
+
+class BoardServiceABC(ABC):
+ """High-level service for board management."""
+
+ @abstractmethod
+ def create(
+ self,
+ board_name: str,
+ ) -> BoardDTO:
+ """Creates a board."""
+ pass
+
+ @abstractmethod
+ def get_dto(
+ self,
+ board_id: str,
+ ) -> BoardDTO:
+ """Gets a board."""
+ pass
+
+ @abstractmethod
+ def update(
+ self,
+ board_id: str,
+ changes: BoardChanges,
+ ) -> BoardDTO:
+ """Updates a board."""
+ pass
+
+ @abstractmethod
+ def delete(
+ self,
+ board_id: str,
+ ) -> None:
+ """Deletes a board."""
+ pass
+
+ @abstractmethod
+ def get_many(
+ self,
+ offset: int = 0,
+ limit: int = 10,
+ ) -> OffsetPaginatedResults[BoardDTO]:
+ """Gets many boards."""
+ pass
+
+ @abstractmethod
+ def get_all(
+ self,
+ ) -> list[BoardDTO]:
+ """Gets all boards."""
+ pass
+
+
+class BoardServiceDependencies:
+ """Service dependencies for the BoardService."""
+
+ board_image_records: BoardImageRecordStorageBase
+ board_records: BoardRecordStorageBase
+ image_records: ImageRecordStorageBase
+ urls: UrlServiceBase
+ logger: Logger
+
+ def __init__(
+ self,
+ board_image_record_storage: BoardImageRecordStorageBase,
+ image_record_storage: ImageRecordStorageBase,
+ board_record_storage: BoardRecordStorageBase,
+ url: UrlServiceBase,
+ logger: Logger,
+ ):
+ self.board_image_records = board_image_record_storage
+ self.image_records = image_record_storage
+ self.board_records = board_record_storage
+ self.urls = url
+ self.logger = logger
+
+
+class BoardService(BoardServiceABC):
+ _services: BoardServiceDependencies
+
+ def __init__(self, services: BoardServiceDependencies):
+ self._services = services
+
+ def create(
+ self,
+ board_name: str,
+ ) -> BoardDTO:
+ board_record = self._services.board_records.save(board_name)
+ return board_record_to_dto(board_record, None, 0)
+
+ def get_dto(self, board_id: str) -> BoardDTO:
+ board_record = self._services.board_records.get(board_id)
+ cover_image = self._services.image_records.get_most_recent_image_for_board(
+ board_record.board_id
+ )
+ if cover_image:
+ cover_image_name = cover_image.image_name
+ else:
+ cover_image_name = None
+ image_count = self._services.board_image_records.get_image_count_for_board(
+ board_id
+ )
+ return board_record_to_dto(board_record, cover_image_name, image_count)
+
+ def update(
+ self,
+ board_id: str,
+ changes: BoardChanges,
+ ) -> BoardDTO:
+ board_record = self._services.board_records.update(board_id, changes)
+ cover_image = self._services.image_records.get_most_recent_image_for_board(
+ board_record.board_id
+ )
+ if cover_image:
+ cover_image_name = cover_image.image_name
+ else:
+ cover_image_name = None
+
+ image_count = self._services.board_image_records.get_image_count_for_board(
+ board_id
+ )
+ return board_record_to_dto(board_record, cover_image_name, image_count)
+
+ def delete(self, board_id: str) -> None:
+ self._services.board_records.delete(board_id)
+
+ def get_many(
+ self, offset: int = 0, limit: int = 10
+ ) -> OffsetPaginatedResults[BoardDTO]:
+ board_records = self._services.board_records.get_many(offset, limit)
+ board_dtos = []
+ for r in board_records.items:
+ cover_image = self._services.image_records.get_most_recent_image_for_board(
+ r.board_id
+ )
+ if cover_image:
+ cover_image_name = cover_image.image_name
+ else:
+ cover_image_name = None
+
+ image_count = self._services.board_image_records.get_image_count_for_board(
+ r.board_id
+ )
+ board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
+
+ return OffsetPaginatedResults[BoardDTO](
+ items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
+ )
+
+ def get_all(self) -> list[BoardDTO]:
+ board_records = self._services.board_records.get_all()
+ board_dtos = []
+ for r in board_records:
+ cover_image = self._services.image_records.get_most_recent_image_for_board(
+ r.board_id
+ )
+ if cover_image:
+ cover_image_name = cover_image.image_name
+ else:
+ cover_image_name = None
+
+ image_count = self._services.board_image_records.get_image_count_for_board(
+ r.board_id
+ )
+ board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
+
+ return board_dtos
\ No newline at end of file
diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py
index 30b379ed8b..c34d2ca5c8 100644
--- a/invokeai/app/services/image_record_storage.py
+++ b/invokeai/app/services/image_record_storage.py
@@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
+ board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
"""Saves an image record."""
pass
+ @abstractmethod
+ def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
+ """Gets the most recent image for a board."""
+ pass
+
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
@@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.release()
def _create_tables(self) -> None:
- """Creates the tables for the `images` database."""
+ """Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
@@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
+ board_id TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
- UPDATE images SET updated_at = current_timestamp
+ UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
@@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""",
(changes.is_intermediate, image_name),
)
+
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
+ board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
+ count_query = """--sql
+ SELECT COUNT(*)
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1
+ """
- count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
- images_query = f"""SELECT * FROM images WHERE 1=1\n"""
+ images_query = """--sql
+ SELECT images.*
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1
+ """
query_conditions = ""
query_params = []
if image_origin is not None:
- query_conditions += f"""AND image_origin = ?\n"""
+ query_conditions += """--sql
+ AND images.image_origin = ?
+ """
query_params.append(image_origin.value)
if categories is not None:
- ## Convert the enum values to unique list of strings
+ # Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
- query_conditions += f"AND image_category IN ( {placeholders} )\n"
+
+ query_conditions += f"""--sql
+ AND images.image_category IN ( {placeholders} )
+ """
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
- query_conditions += f"""AND is_intermediate = ?\n"""
+ query_conditions += """--sql
+ AND images.is_intermediate = ?
+ """
+
query_params.append(is_intermediate)
- query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
+ if board_id is not None:
+ query_conditions += """--sql
+ AND board_images.board_id = ?
+ """
+
+ query_params.append(board_id)
+
+ query_pagination = """--sql
+ ORDER BY images.created_at DESC LIMIT ? OFFSET ?
+ """
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
@@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
- count = self._cursor.fetchone()[0]
+ count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
@@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordSaveException from e
finally:
self._lock.release()
+
+ def get_most_recent_image_for_board(
+ self, board_id: str
+ ) -> Union[ImageRecord, None]:
+ try:
+ self._lock.acquire()
+ self._cursor.execute(
+ """--sql
+ SELECT images.*
+ FROM images
+ JOIN board_images ON images.image_name = board_images.image_name
+ WHERE board_images.board_id = ?
+ ORDER BY images.created_at DESC
+ LIMIT 1;
+ """,
+ (board_id,),
+ )
+
+ result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
+ finally:
+ self._lock.release()
+ if result is None:
+ return None
+
+ return deserialize_image_record(dict(result))
diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py
index 9f7188f607..542f874f1d 100644
--- a/invokeai/app/services/images.py
+++ b/invokeai/app/services/images.py
@@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
+from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
@@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
- intermediate: bool = False,
+ is_intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
- def get_path(self, image_name: str) -> str:
+ def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""
pass
@@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
+ board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
- records: ImageRecordStorageBase
- files: ImageFileStorageBase
+ image_records: ImageRecordStorageBase
+ image_files: ImageFileStorageBase
+ board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
@@ -126,14 +129,16 @@ class ImageServiceDependencies:
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
+ board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
- self.records = image_record_storage
- self.files = image_file_storage
+ self.image_records = image_record_storage
+ self.image_files = image_file_storage
+ self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url
self.logger = logger
@@ -144,25 +149,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
- def __init__(
- self,
- image_record_storage: ImageRecordStorageBase,
- image_file_storage: ImageFileStorageBase,
- metadata: MetadataServiceBase,
- url: UrlServiceBase,
- logger: Logger,
- names: NameServiceBase,
- graph_execution_manager: ItemStorageABC["GraphExecutionState"],
- ):
- self._services = ImageServiceDependencies(
- image_record_storage=image_record_storage,
- image_file_storage=image_file_storage,
- metadata=metadata,
- url=url,
- logger=logger,
- names=names,
- graph_execution_manager=graph_execution_manager,
- )
+ def __init__(self, services: ImageServiceDependencies):
+ self._services = services
def create(
self,
@@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
- created_at = self._services.records.save(
+ self._services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
@@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
metadata=metadata,
)
- self._services.files.save(
+ self._services.image_files.save(
image_name=image_name,
image=image,
metadata=metadata,
)
- image_url = self._services.urls.get_image_url(image_name)
- thumbnail_url = self._services.urls.get_image_url(image_name, True)
+ image_dto = self.get_dto(image_name)
- return ImageDTO(
- # Non-nullable fields
- image_name=image_name,
- image_origin=image_origin,
- image_category=image_category,
- width=width,
- height=height,
- # Nullable fields
- node_id=node_id,
- session_id=session_id,
- metadata=metadata,
- # Meta fields
- created_at=created_at,
- updated_at=created_at, # this is always the same as the created_at at this time
- deleted_at=None,
- is_intermediate=is_intermediate,
- # Extra non-nullable fields for DTO
- image_url=image_url,
- thumbnail_url=thumbnail_url,
- )
+ return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
@@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges,
) -> ImageDTO:
try:
- self._services.records.update(image_name, changes)
+ self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
@@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_name: str) -> PILImageType:
try:
- return self._services.files.get(image_name)
+ return self._services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_name: str) -> ImageRecord:
try:
- return self._services.records.get(image_name)
+ return self._services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
def get_dto(self, image_name: str) -> ImageDTO:
try:
- image_record = self._services.records.get(image_name)
+ image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
+ self._services.board_image_records.get_board_for_image(image_name),
)
return image_dto
@@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
- return self._services.files.get_path(image_name, thumbnail)
+ return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
- return self._services.files.validate_path(path)
+ return self._services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
@@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
+ board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
- results = self._services.records.get_many(
+ results = self._services.image_records.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
+ board_id,
)
image_dtos = list(
@@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
+ self._services.board_image_records.get_board_for_image(
+ r.image_name
+ ),
),
results.items,
)
@@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
def delete(self, image_name: str):
try:
- self._services.files.delete(image_name)
- self._services.records.delete(image_name)
+ self._services.image_files.delete(image_name)
+ self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py
index 1f910253e5..10d1d91920 100644
--- a/invokeai/app/services/invocation_services.py
+++ b/invokeai/app/services/invocation_services.py
@@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
- from invokeai.app.services.images import ImageService
+ from invokeai.app.services.board_images import BoardImagesServiceABC
+ from invokeai.app.services.boards import BoardServiceABC
+ from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
@@ -26,9 +28,9 @@ class InvocationServices:
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
- images: "ImageService"
-
- # NOTE: we must forward-declare any types that include invocations, since invocations can use services
+ images: "ImageServiceABC"
+ boards: "BoardServiceABC"
+ board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
@@ -39,7 +41,9 @@ class InvocationServices:
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
- images: "ImageService",
+ images: "ImageServiceABC",
+ boards: "BoardServiceABC",
+ board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
@@ -52,9 +56,12 @@ class InvocationServices:
self.logger = logger
self.latents = latents
self.images = images
+ self.boards = boards
+ self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
+ self.boards = boards
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index c212ff6a72..8b46b17ad0 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -5,7 +5,7 @@ from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
+from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
@@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
) -> bool:
pass
- @abstractmethod
- def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
- """
- Returns the name and typeof the default model, or None
- if none is defined.
- """
- pass
-
- @abstractmethod
- def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
- """Sets the default model to the indicated name."""
- pass
-
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
@@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_type,
)
- def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
- """
- Returns the name of the default model, or None
- if none is defined.
- """
- return self.mgr.default_model()
-
- def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
- """Sets the default model to the indicated name."""
- self.mgr.set_default_model(model_name, base_model, model_type)
-
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
@@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
- ) -> dict:
+ ) -> list[dict]:
+ # ) -> dict:
"""
- Return a dict of models in the format:
- { model_type1:
- { model_name1: {'status': 'active'|'cached'|'not loaded',
- 'model_name' : name,
- 'model_type' : SDModelType,
- 'description': description,
- 'format': 'folder'|'safetensors'|'ckpt'
- },
- model_name2: { etc }
- },
- model_type2:
- { model_name_n: etc
- }
+ Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)
diff --git a/invokeai/app/services/models/board_record.py b/invokeai/app/services/models/board_record.py
new file mode 100644
index 0000000000..bf5401b209
--- /dev/null
+++ b/invokeai/app/services/models/board_record.py
@@ -0,0 +1,62 @@
+from typing import Optional, Union
+from datetime import datetime
+from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
+from invokeai.app.util.misc import get_iso_timestamp
+
+
+class BoardRecord(BaseModel):
+ """Deserialized board record."""
+
+ board_id: str = Field(description="The unique ID of the board.")
+ """The unique ID of the board."""
+ board_name: str = Field(description="The name of the board.")
+ """The name of the board."""
+ created_at: Union[datetime, str] = Field(
+ description="The created timestamp of the board."
+ )
+ """The created timestamp of the image."""
+ updated_at: Union[datetime, str] = Field(
+ description="The updated timestamp of the board."
+ )
+ """The updated timestamp of the image."""
+ deleted_at: Union[datetime, str, None] = Field(
+ description="The deleted timestamp of the board."
+ )
+ """The updated timestamp of the image."""
+ cover_image_name: Optional[str] = Field(
+ description="The name of the cover image of the board."
+ )
+ """The name of the cover image of the board."""
+
+
+class BoardDTO(BoardRecord):
+ """Deserialized board record with cover image URL and image count."""
+
+ cover_image_name: Optional[str] = Field(
+ description="The name of the board's cover image."
+ )
+ """The URL of the thumbnail of the most recent image in the board."""
+ image_count: int = Field(description="The number of images in the board.")
+ """The number of images in the board."""
+
+
+def deserialize_board_record(board_dict: dict) -> BoardRecord:
+ """Deserializes a board record."""
+
+ # Retrieve all the values, setting "reasonable" defaults if they are not present.
+
+ board_id = board_dict.get("board_id", "unknown")
+ board_name = board_dict.get("board_name", "unknown")
+ cover_image_name = board_dict.get("cover_image_name", "unknown")
+ created_at = board_dict.get("created_at", get_iso_timestamp())
+ updated_at = board_dict.get("updated_at", get_iso_timestamp())
+ deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
+
+ return BoardRecord(
+ board_id=board_id,
+ board_name=board_name,
+ cover_image_name=cover_image_name,
+ created_at=created_at,
+ updated_at=updated_at,
+ deleted_at=deleted_at,
+ )
diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py
index d971d65916..cc02016cf9 100644
--- a/invokeai/app/services/models/image_record.py
+++ b/invokeai/app/services/models/image_record.py
@@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO):
- """Deserialized image record, enriched for the frontend with URLs."""
+ """Deserialized image record, enriched for the frontend."""
+ board_id: Union[str, None] = Field(
+ description="The id of the board the image belongs to, if one exists."
+ )
+ """The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
- image_record: ImageRecord, image_url: str, thumbnail_url: str
+ image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
+ board_id=board_id,
)
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index e56ec2a0d2..4e2c789c07 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -266,6 +266,8 @@ class ModelManager(object):
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
+ # alias for config file
+ model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
@@ -446,38 +448,6 @@ class ModelManager(object):
_cache = self.cache,
)
- def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
- """
- Returns the name of the default model, or None
- if none is defined.
- """
- for model_key, model_config in self.models.items():
- if model_config.default:
- return self.parse_key(model_key)
-
- for model_key, _ in self.models.items():
- return self.parse_key(model_key)
- else:
- return None # TODO: or redo as (None, None, None)
-
- def set_default_model(
- self,
- model_name: str,
- base_model: BaseModelType,
- model_type: ModelType,
- ) -> None:
- """
- Set the default model. The change will not take
- effect until you call model_manager.commit()
- """
-
- model_key = self.model_key(model_name, base_model, model_type)
- if model_key not in self.models:
- raise Exception(f"Unknown model: {model_key}")
-
- for cur_model_key, config in self.models.items():
- config.default = cur_model_key == model_key
-
def model_info(
self,
model_name: str,
@@ -504,9 +474,9 @@ class ModelManager(object):
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
- ) -> Dict[str, Dict[str, str]]:
+ ) -> list[dict]:
"""
- Return a dict of models, in format [base_model][model_type][model_name]
+ Return a list of models.
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
@@ -514,7 +484,7 @@ class ModelManager(object):
object derived from models.yaml
"""
- models = dict()
+ models = []
for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key]
@@ -524,18 +494,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type:
continue
- if cur_base_model not in models:
- models[cur_base_model] = dict()
- if cur_model_type not in models[cur_base_model]:
- models[cur_base_model][cur_model_type] = dict()
-
- models[cur_base_model][cur_model_type][cur_model_name] = dict(
+ model_dict = dict(
**model_config.dict(exclude_defaults=True),
+ # OpenAPIModelInfoBase
name=cur_model_name,
base_model=cur_base_model,
type=cur_model_type,
)
+ models.append(model_dict)
+
return models
def print_models(self) -> None:
@@ -647,7 +615,9 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
- data_to_save[model_key] = model_config.dict(exclude_defaults=True)
+ data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
+ # alias for config file
+ data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path
diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py
index 40995498bf..6975d45f93 100644
--- a/invokeai/backend/model_management/models/__init__.py
+++ b/invokeai/backend/model_management/models/__init__.py
@@ -1,3 +1,7 @@
+import inspect
+from enum import Enum
+from pydantic import BaseModel
+from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
@@ -29,10 +33,63 @@ MODEL_CLASSES = {
#},
}
-def get_all_model_configs():
- configs = set()
- for models in MODEL_CLASSES.values():
- for _, model in models.items():
- configs.update(model._get_configs().values())
- configs.discard(None)
- return list(configs) # TODO: set, list or tuple
+MODEL_CONFIGS = list()
+OPENAPI_MODEL_CONFIGS = list()
+
+class OpenAPIModelInfoBase(BaseModel):
+ name: str
+ base_model: BaseModelType
+ type: ModelType
+
+
+for base_model, models in MODEL_CLASSES.items():
+ for model_type, model_class in models.items():
+ model_configs = set(model_class._get_configs().values())
+ model_configs.discard(None)
+ MODEL_CONFIGS.extend(model_configs)
+
+ for cfg in model_configs:
+ model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
+ openapi_cfg_name = model_name + cfg_name
+ if openapi_cfg_name in vars():
+ continue
+
+ api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
+ __annotations__ = dict(
+ type=Literal[model_type.value],
+ ),
+ ))
+
+ #globals()[openapi_cfg_name] = api_wrapper
+ vars()[openapi_cfg_name] = api_wrapper
+ OPENAPI_MODEL_CONFIGS.append(api_wrapper)
+
+def get_model_config_enums():
+ enums = list()
+
+ for model_config in MODEL_CONFIGS:
+ fields = inspect.get_annotations(model_config)
+ try:
+ field = fields["model_format"]
+ except:
+ raise Exception("format field not found")
+
+ # model_format: None
+ # model_format: SomeModelFormat
+ # model_format: Literal[SomeModelFormat.Diffusers]
+ # model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
+
+ if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
+ enums.append(field)
+
+ elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
+ enums.append(type(field.__args__[0]))
+
+ elif field is None:
+ pass
+
+ else:
+ raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
+
+ return enums
+
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index f18099b4e7..ef354ecc07 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -48,12 +48,10 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel):
path: str # or Path
- #name: str # not included as present in model key
description: Optional[str] = Field(None)
- format: Optional[str] = Field(None)
- default: Optional[bool] = Field(False)
+ model_format: Optional[str] = Field(None)
# do not save to config
- error: Optional[ModelError] = Field(None, exclude=True)
+ error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
@@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
+ if all(t is None for t in subtypes):
+ return None
+ elif any(t is None for t in subtypes):
+ raise Exception(f"Unsupported definition: {subtypes}")
+
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
@@ -122,46 +125,41 @@ class ModelBase(metaclass=ABCMeta):
continue
fields = inspect.get_annotations(value)
- if "format" not in fields:
- raise Exception("Invalid config definition - format field not found")
+ try:
+ field = fields["model_format"]
+ except:
+ raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
- format_type = typing.get_origin(fields["format"])
- if format_type not in {None, Literal, Union}:
- raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
+ if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
+ for model_format in field:
+ configs[model_format.value] = value
- if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
- raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
+ elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
+ for model_format in field.__args__:
+ configs[model_format.value] = value
+
+ elif field is None:
+ configs[None] = value
- if format_type == Union:
- f_fields = fields["format"].__args__
else:
- f_fields = (fields["format"],)
-
-
- for field in f_fields:
- if field is None:
- format_name = None
- else:
- format_name = field.__args__[0]
-
- configs[format_name] = value # TODO: error when override(multiple)?
-
+ raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
- if "format" not in kwargs:
- raise Exception("Field 'format' not found in model config")
+ if "model_format" not in kwargs:
+ raise Exception("Field 'model_format' not found in model config")
+
configs = cls._get_configs()
- return configs[kwargs["format"]](**kwargs)
+ return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
- format=cls.detect_format(path),
+ model_format=cls.detect_format(path),
)
@classmethod
diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py
index d75c55010a..9563f87afd 100644
--- a/invokeai/backend/model_management/models/controlnet.py
+++ b/invokeai/backend/model_management/models/controlnet.py
@@ -1,5 +1,6 @@
import os
import torch
+from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@@ -14,12 +15,16 @@ from .base import (
classproperty,
)
+class ControlNetModelFormat(str, Enum):
+ Checkpoint = "checkpoint"
+ Diffusers = "diffusers"
+
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
- format: Union[Literal["checkpoint"], Literal["diffusers"]]
+ model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
- return "diffusers"
+ return ControlNetModelFormat.Diffusers
else:
- return "checkpoint"
+ return ControlNetModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
- if cls.detect_format(model_path) != "diffusers":
- raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
+ if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
+ raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py
index bcf3224ece..59feacde06 100644
--- a/invokeai/backend/model_management/models/lora.py
+++ b/invokeai/backend/model_management/models/lora.py
@@ -1,5 +1,6 @@
import os
import torch
+from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
@@ -12,11 +13,15 @@ from .base import (
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
+class LoRAModelFormat(str, Enum):
+ LyCORIS = "lycoris"
+ Diffusers = "diffusers"
+
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
- format: Union[Literal["lycoris"], Literal["diffusers"]]
+ model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
- return "diffusers"
+ return LoRAModelFormat.Diffusers
else:
- return "lycoris"
+ return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
@@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
- if cls.detect_format(model_path) == "diffusers":
+ if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:
diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py
index bd519c88c8..f169326571 100644
--- a/invokeai/backend/model_management/models/stable_diffusion.py
+++ b/invokeai/backend/model_management/models/stable_diffusion.py
@@ -1,5 +1,6 @@
import os
import json
+from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
+class StableDiffusion1ModelFormat(str, Enum):
+ Checkpoint = "checkpoint"
+ Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
- format: Literal["diffusers"]
+ model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
- format: Literal["checkpoint"]
+ model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
- if model_format == "checkpoint":
+ if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
- elif model_format == "diffusers":
+ elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
- format=model_format,
+ model_format=model_format,
config=ckpt_config_path,
variant=variant,
@@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
- return "diffusers"
+ return StableDiffusion1ModelFormat.Diffusers
else:
- return "checkpoint"
+ return StableDiffusion1ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else:
return model_path
+class StableDiffusion2ModelFormat(str, Enum):
+ Checkpoint = "checkpoint"
+ Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
- format: Literal["diffusers"]
+ model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
- format: Literal["checkpoint"]
+ model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
- if model_format == "checkpoint":
+ if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
- elif model_format == "diffusers":
+ elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
- format=model_format,
+ model_format=model_format,
config=ckpt_config_path,
variant=variant,
@@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
- return "diffusers"
+ return StableDiffusion2ModelFormat.Diffusers
else:
- return "checkpoint"
+ return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
- upcast_attention = config.upcast_attention
- prediction_type = config.prediction_type
+ upcast_attention = model_config.upcast_attention
+ prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")
diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py
index 66847f53eb..9a032218f0 100644
--- a/invokeai/backend/model_management/models/textual_inversion.py
+++ b/invokeai/backend/model_management/models/textual_inversion.py
@@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
- format: None
+ model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion
diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py
index 1edb57ccc4..76133b074d 100644
--- a/invokeai/backend/model_management/models/vae.py
+++ b/invokeai/backend/model_management/models/vae.py
@@ -1,5 +1,7 @@
import os
import torch
+import safetensors
+from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
+class VaeModelFormat(str, Enum):
+ Checkpoint = "checkpoint"
+ Diffusers = "diffusers"
+
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
- format: Union[Literal["checkpoint"], Literal["diffusers"]]
+ model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
@@ -70,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
- return "diffusers"
+ return VaeModelFormat.Diffusers
else:
- return "checkpoint"
+ return VaeModelFormat.Checkpoint
@classmethod
def convert_if_required(
@@ -82,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
- if cls.detect_format(model_path) != "diffusers":
+ if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,
diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx
index ddc6dace27..55fcc97745 100644
--- a/invokeai/frontend/web/src/app/components/App.tsx
+++ b/invokeai/frontend/web/src/app/components/App.tsx
@@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
+import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
+import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {};
@@ -45,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady();
+ const { data: pipelineModels } = useListModelsQuery({
+ model_type: 'pipeline',
+ });
+ const { data: controlnetModels } = useListModelsQuery({
+ model_type: 'controlnet',
+ });
+ const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
+ const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
+ const { data: embeddingModels } = useListModelsQuery({
+ model_type: 'embedding',
+ });
+
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch();
@@ -143,6 +157,7 @@ const App = ({
+
>
diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
index 0537d1de2a..141e62652d 100644
--- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
+++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx
@@ -21,6 +21,8 @@ import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
+import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
+import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@@ -76,11 +78,13 @@ const InvokeAIUI = ({
-
+
+
+
diff --git a/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx
new file mode 100644
index 0000000000..f5a856d3d8
--- /dev/null
+++ b/invokeai/frontend/web/src/app/contexts/AddImageToBoardContext.tsx
@@ -0,0 +1,89 @@
+import { useDisclosure } from '@chakra-ui/react';
+import { PropsWithChildren, createContext, useCallback, useState } from 'react';
+import { ImageDTO } from 'services/api';
+import { useAddImageToBoardMutation } from 'services/apiSlice';
+
+export type ImageUsage = {
+ isInitialImage: boolean;
+ isCanvasImage: boolean;
+ isNodesImage: boolean;
+ isControlNetImage: boolean;
+};
+
+type AddImageToBoardContextValue = {
+ /**
+ * Whether the move image dialog is open.
+ */
+ isOpen: boolean;
+ /**
+ * Closes the move image dialog.
+ */
+ onClose: () => void;
+ /**
+ * The image pending movement
+ */
+ image?: ImageDTO;
+ onClickAddToBoard: (image: ImageDTO) => void;
+ handleAddToBoard: (boardId: string) => void;
+};
+
+export const AddImageToBoardContext =
+ createContext({
+ isOpen: false,
+ onClose: () => undefined,
+ onClickAddToBoard: () => undefined,
+ handleAddToBoard: () => undefined,
+ });
+
+type Props = PropsWithChildren;
+
+export const AddImageToBoardContextProvider = (props: Props) => {
+ const [imageToMove, setImageToMove] = useState();
+ const { isOpen, onOpen, onClose } = useDisclosure();
+
+ const [addImageToBoard, result] = useAddImageToBoardMutation();
+
+ // Clean up after deleting or dismissing the modal
+ const closeAndClearImageToDelete = useCallback(() => {
+ setImageToMove(undefined);
+ onClose();
+ }, [onClose]);
+
+ const onClickAddToBoard = useCallback(
+ (image?: ImageDTO) => {
+ if (!image) {
+ return;
+ }
+ setImageToMove(image);
+ onOpen();
+ },
+ [setImageToMove, onOpen]
+ );
+
+ const handleAddToBoard = useCallback(
+ (boardId: string) => {
+ if (imageToMove) {
+ addImageToBoard({
+ board_id: boardId,
+ image_name: imageToMove.image_name,
+ });
+ closeAndClearImageToDelete();
+ }
+ },
+ [addImageToBoard, closeAndClearImageToDelete, imageToMove]
+ );
+
+ return (
+
+ {props.children}
+
+ );
+};
diff --git a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
index 8263b48114..d01298944b 100644
--- a/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
+++ b/invokeai/frontend/web/src/app/contexts/DeleteImageContext.tsx
@@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
- const isInitialImage = generation.initialImage?.image_name === image_name;
+ const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
- (obj) => obj.kind === 'image' && obj.image.image_name === image_name
+ (obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
- (input) =>
- input.type === 'image' && input.value?.image_name === image_name
+ (input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
- c.controlImage?.image_name === image_name ||
- c.processedControlImage?.image_name === image_name
+ c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {
diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
index 5025ca081a..cb18d48301 100644
--- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
+++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts
@@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
-import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
@@ -18,7 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
- models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
index c6af5f3612..8f40b0bb59 100644
--- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
+++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/unserialize.ts
@@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
-import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
@@ -21,7 +20,6 @@ const initialStates: {
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
- models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
system: initialSystemState,
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
index 8c073e81d6..cb641d00db 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts
@@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
+import {
+ addImageAddedToBoardFulfilledListener,
+ addImageAddedToBoardRejectedListener,
+} from './listeners/imageAddedToBoard';
+import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
+import {
+ addImageRemovedFromBoardFulfilledListener,
+ addImageRemovedFromBoardRejectedListener,
+} from './listeners/imageRemovedFromBoard';
export const listenerMiddleware = createListenerMiddleware();
@@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch
>;
+/**
+ * The RTK listener middleware is a lightweight alternative sagas/observables.
+ *
+ * Most side effect logic should live in a listener.
+ */
+
// Image uploaded
addImageUploadedFulfilledListener();
addImageUploadedRejectedListener();
@@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
// Update image URLs on connect
addUpdateImageUrlsOnConnectListener();
+
+// Boards
+addImageAddedToBoardFulfilledListener();
+addImageAddedToBoardRejectedListener();
+addImageRemovedFromBoardFulfilledListener();
+addImageRemovedFromBoardRejectedListener();
+addBoardIdSelectedListener();
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts
new file mode 100644
index 0000000000..eab4389ceb
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts
@@ -0,0 +1,99 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { boardIdSelected } from 'features/gallery/store/boardSlice';
+import { selectImagesAll } from 'features/gallery/store/imagesSlice';
+import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
+import { api } from 'services/apiSlice';
+import { imageSelected } from 'features/gallery/store/gallerySlice';
+
+const moduleLog = log.child({ namespace: 'boards' });
+
+export const addBoardIdSelectedListener = () => {
+ startAppListening({
+ actionCreator: boardIdSelected,
+ effect: (action, { getState, dispatch }) => {
+ const boardId = action.payload;
+
+ // we need to check if we need to fetch more images
+
+ const state = getState();
+ const allImages = selectImagesAll(state);
+
+ if (!boardId) {
+ // a board was unselected
+ dispatch(imageSelected(allImages[0]?.image_name));
+ return;
+ }
+
+ const { categories } = state.images;
+
+ const filteredImages = allImages.filter((i) => {
+ const isInCategory = categories.includes(i.image_category);
+ const isInSelectedBoard = boardId ? i.board_id === boardId : true;
+ return isInCategory && isInSelectedBoard;
+ });
+
+ // get the board from the cache
+ const { data: boards } = api.endpoints.listAllBoards.select()(state);
+ const board = boards?.find((b) => b.board_id === boardId);
+
+ if (!board) {
+ // can't find the board in cache...
+ dispatch(imageSelected(allImages[0]?.image_name));
+ return;
+ }
+
+ dispatch(imageSelected(board.cover_image_name));
+
+ // if we haven't loaded one full page of images from this board, load more
+ if (
+ filteredImages.length < board.image_count &&
+ filteredImages.length < IMAGES_PER_PAGE
+ ) {
+ dispatch(receivedPageOfImages({ categories, boardId }));
+ }
+ },
+ });
+};
+
+export const addBoardIdSelected_changeSelectedImage_listener = () => {
+ startAppListening({
+ actionCreator: boardIdSelected,
+ effect: (action, { getState, dispatch }) => {
+ const boardId = action.payload;
+
+ const state = getState();
+
+ // we need to check if we need to fetch more images
+
+ if (!boardId) {
+ // a board was unselected - we don't need to do anything
+ return;
+ }
+
+ const { categories } = state.images;
+
+ const filteredImages = selectImagesAll(state).filter((i) => {
+ const isInCategory = categories.includes(i.image_category);
+ const isInSelectedBoard = boardId ? i.board_id === boardId : true;
+ return isInCategory && isInSelectedBoard;
+ });
+
+ // get the board from the cache
+ const { data: boards } = api.endpoints.listAllBoards.select()(state);
+ const board = boards?.find((b) => b.board_id === boardId);
+ if (!board) {
+ // can't find the board in cache...
+ return;
+ }
+
+ // if we haven't loaded one full page of images from this board, load more
+ if (
+ filteredImages.length < board.image_count &&
+ filteredImages.length < IMAGES_PER_PAGE
+ ) {
+ dispatch(receivedPageOfImages({ categories, boardId }));
+ }
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
index ce1b515b84..7ff9a5118c 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts
@@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
- image: pick(controlNet.controlImage, ['image_name']),
+ image: { image_name: controlNet.controlImage },
},
},
};
@@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
- processedControlImage,
+ processedControlImage: processedControlImage.image_name,
})
);
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts
new file mode 100644
index 0000000000..0f404cab68
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard.ts
@@ -0,0 +1,40 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageMetadataReceived } from 'services/thunks/image';
+import { api } from 'services/apiSlice';
+
+const moduleLog = log.child({ namespace: 'boards' });
+
+export const addImageAddedToBoardFulfilledListener = () => {
+ startAppListening({
+ matcher: api.endpoints.addImageToBoard.matchFulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const { board_id, image_name } = action.meta.arg.originalArgs;
+
+ moduleLog.debug(
+ { data: { board_id, image_name } },
+ 'Image added to board'
+ );
+
+ dispatch(
+ imageMetadataReceived({
+ imageName: image_name,
+ })
+ );
+ },
+ });
+};
+
+export const addImageAddedToBoardRejectedListener = () => {
+ startAppListening({
+ matcher: api.endpoints.addImageToBoard.matchRejected,
+ effect: (action, { getState, dispatch }) => {
+ const { board_id, image_name } = action.meta.arg.originalArgs;
+
+ moduleLog.debug(
+ { data: { board_id, image_name } },
+ 'Problem adding image to board'
+ );
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts
index 85d56d3913..8f01b8d7b8 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageCategoriesChanged.ts
@@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
- const filteredImagesCount = selectFilteredImagesAsArray(
- getState()
- ).length;
+ const state = getState();
+ const filteredImagesCount = selectFilteredImagesAsArray(state).length;
if (!filteredImagesCount) {
- dispatch(receivedPageOfImages());
+ dispatch(
+ receivedPageOfImages({
+ categories: action.payload,
+ boardId: state.boards.selectedBoardId,
+ })
+ );
}
},
});
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
index 4c0c057242..224aa0d2aa 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts
@@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
imageRemoved,
- selectImagesEntities,
selectImagesIds,
} from 'features/gallery/store/imagesSlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
+import { api } from 'services/apiSlice';
-const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
+const moduleLog = log.child({ namespace: 'image' });
/**
* Called when the user requests an image deletion
@@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
- effect: (action, { dispatch, getState }) => {
+ effect: async (action, { dispatch, getState, condition }) => {
const { image, imageUsage } = action.payload;
const { image_name } = image;
@@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
const state = getState();
const selectedImage = state.gallery.selectedImage;
- if (selectedImage && selectedImage.image_name === image_name) {
+ if (selectedImage === image_name) {
const ids = selectImagesIds(state);
- const entities = selectImagesEntities(state);
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
@@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex];
- const newSelectedImage = entities[newSelectedImageId];
-
if (newSelectedImageId) {
- dispatch(imageSelected(newSelectedImage));
+ dispatch(imageSelected(newSelectedImageId as string));
} else {
dispatch(imageSelected());
}
@@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name));
// Delete from server
- dispatch(imageDeleted({ imageName: image_name }));
+ const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
+
+ // Wait for successful deletion, then trigger boards to re-fetch
+ const wasImageDeleted = await condition(
+ (action): action is ReturnType =>
+ imageDeleted.fulfilled.match(action) &&
+ action.meta.requestId === requestId,
+ 30000
+ );
+
+ if (wasImageDeleted) {
+ dispatch(
+ api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
+ );
+ }
},
});
};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts
new file mode 100644
index 0000000000..40847ade3a
--- /dev/null
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard.ts
@@ -0,0 +1,40 @@
+import { log } from 'app/logging/useLogger';
+import { startAppListening } from '..';
+import { imageMetadataReceived } from 'services/thunks/image';
+import { api } from 'services/apiSlice';
+
+const moduleLog = log.child({ namespace: 'boards' });
+
+export const addImageRemovedFromBoardFulfilledListener = () => {
+ startAppListening({
+ matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
+ effect: (action, { getState, dispatch }) => {
+ const { board_id, image_name } = action.meta.arg.originalArgs;
+
+ moduleLog.debug(
+ { data: { board_id, image_name } },
+ 'Image added to board'
+ );
+
+ dispatch(
+ imageMetadataReceived({
+ imageName: image_name,
+ })
+ );
+ },
+ });
+};
+
+export const addImageRemovedFromBoardRejectedListener = () => {
+ startAppListening({
+ matcher: api.endpoints.removeImageFromBoard.matchRejected,
+ effect: (action, { getState, dispatch }) => {
+ const { board_id, image_name } = action.meta.arg.originalArgs;
+
+ moduleLog.debug(
+ { data: { board_id, image_name } },
+ 'Problem adding image to board'
+ );
+ },
+ });
+};
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
index 40ed062353..fc44d206c8 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts
@@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
- dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
+ dispatch(
+ controlNetImageChanged({
+ controlNetId,
+ controlImage: image.image_name,
+ })
+ );
return;
}
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts
index 3049d2c933..bf54e63836 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts
@@ -1,9 +1,8 @@
-import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image';
-import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
+import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
@@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected');
- const { models, nodes, config, images } = getState();
+ const { nodes, config, images } = getState();
const { disabledTabs } = config;
if (!images.ids.length) {
- dispatch(receivedPageOfImages());
- }
-
- if (!models.ids.length) {
- dispatch(receivedModels());
+ dispatch(
+ receivedPageOfImages({
+ categories: ['general'],
+ isIntermediate: false,
+ })
+ );
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
index c9ab894ddb..c204f0bdfb 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts
@@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
+import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
@@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
const sessionId = action.payload.data.graph_execution_state_id;
- const { cancelType, isCancelScheduled } = getState().system;
+ const { cancelType, isCancelScheduled, boardIdToAddTo } =
+ getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
@@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
dispatch(addImageToStagingArea(imageDTO));
}
+ if (boardIdToAddTo && !imageDTO.is_intermediate) {
+ dispatch(
+ api.endpoints.addImageToBoard.initiate({
+ board_id: boardIdToAddTo,
+ image_name,
+ })
+ );
+ }
+
dispatch(progressImageSet(null));
}
// pass along the socket event as an application action
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
index 7cb8012848..b9ddcea4c3 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts
@@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
selectImagesEntities,
],
(generation, canvas, nodes, controlNet, imageEntities) => {
- const allUsedImages: ImageDTO[] = [];
+ const allUsedImages: string[] = [];
if (generation.initialImage) {
- allUsedImages.push(generation.initialImage);
+ allUsedImages.push(generation.initialImage.imageName);
}
canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') {
- allUsedImages.push(obj.image);
+ allUsedImages.push(obj.imageName);
}
});
@@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
forEach(imageEntities, (image) => {
if (image) {
- allUsedImages.push(image);
+ allUsedImages.push(image.image_name);
}
});
@@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
- allUsedImages.forEach(({ image_name }) => {
+ allUsedImages.forEach((image_name) => {
dispatch(
imageUrlsReceived({
imageName: image_name,
diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts
index f577b73895..57a97168a3 100644
--- a/invokeai/frontend/web/src/app/store/store.ts
+++ b/invokeai/frontend/web/src/app/store/store.ts
@@ -5,40 +5,39 @@ import {
configureStore,
} from '@reduxjs/toolkit';
-import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
+import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice';
+import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
-import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
-import configReducer from 'features/system/store/configSlice';
-import uiReducer from 'features/ui/store/uiSlice';
-import hotkeysReducer from 'features/ui/store/hotkeysSlice';
-import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
+import boardsReducer from 'features/gallery/store/boardSlice';
+import configReducer from 'features/system/store/configSlice';
+import hotkeysReducer from 'features/ui/store/hotkeysSlice';
+import uiReducer from 'features/ui/store/uiSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
-import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
-
+import { stateSanitizer } from './middleware/devtools/stateSanitizer';
+import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
-import { LOCALSTORAGE_PREFIX } from './constants';
+import { api } from 'services/apiSlice';
const allReducers = {
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
- models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
system: systemReducer,
@@ -47,7 +46,9 @@ const allReducers = {
hotkeys: hotkeysReducer,
images: imagesReducer,
controlNet: controlNetReducer,
+ boards: boardsReducer,
// session: sessionReducer,
+ [api.reducerPath]: api.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'gallery',
'generation',
'lightbox',
- // 'models',
'nodes',
'postprocessing',
'system',
'ui',
'controlNet',
+ // 'boards',
// 'hotkeys',
// 'config',
];
@@ -84,6 +85,7 @@ export const store = configureStore({
immutableCheck: false,
serializableCheck: false,
})
+ .concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {
diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
index 669a68c88a..e54b4a8872 100644
--- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
@@ -9,7 +9,7 @@ import {
import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton';
-import { IAIImageFallback } from 'common/components/IAIImageFallback';
+import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent, useCallback } from 'react';
@@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
isDropDisabled = false,
isDragDisabled = false,
isUploadDisabled = false,
- fallback = ,
+ fallback = ,
payloadImage,
minSize = 24,
postUploadAction,
diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
index 3d34fbca9e..03a00d5b1c 100644
--- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx
@@ -1,10 +1,20 @@
-import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
+import {
+ As,
+ Flex,
+ FlexProps,
+ Icon,
+ IconProps,
+ Spinner,
+ SpinnerProps,
+} from '@chakra-ui/react';
+import { ReactElement } from 'react';
+import { FaImage } from 'react-icons/fa';
type Props = FlexProps & {
spinnerProps?: SpinnerProps;
};
-export const IAIImageFallback = (props: Props) => {
+export const IAIImageLoadingFallback = (props: Props) => {
const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest;
return (
@@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
);
};
+
+type IAINoImageFallbackProps = {
+ flexProps?: FlexProps;
+ iconProps?: IconProps;
+ as?: As;
+};
+
+export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
+ const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
+ const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
+ return (
+
+
+
+ );
+};
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
index b8757eff0c..c3132f0285 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx
@@ -1,14 +1,21 @@
-import { Image } from 'react-konva';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
+import { Image, Rect } from 'react-konva';
+import { useGetImageDTOQuery } from 'services/apiSlice';
import useImage from 'use-image';
+import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageProps = {
- url: string;
- x: number;
- y: number;
+ canvasImage: CanvasImage;
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
- const { url, x, y } = props;
- const [image] = useImage(url, 'anonymous');
+ const { width, height, x, y, imageName } = props.canvasImage;
+ const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
+ const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
+
+ if (!imageDTO) {
+ return ;
+ }
+
return ;
};
diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
index ea04aa95c8..ec1e87cca7 100644
--- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
+++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasObjectRenderer.tsx
@@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
- return (
-
- );
+ return ;
} else if (isCanvasBaseLine(obj)) {
const line = (
{
return (
{shouldShowStagingImage && currentStagingAreaImage && (
-
+
)}
{shouldShowStagingOutline && (
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
index b7092bf7e0..3e40c1211d 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts
@@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
y: 0,
width: width,
height: height,
- image: image,
+ imageName: image.image_name,
},
],
};
@@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
kind: 'image',
layer: 'base',
...state.layerState.stagingArea.boundingBox,
- image,
+ imageName: image.image_name,
});
state.layerState.stagingArea.selectedImageIndex =
@@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
state.doesCanvasNeedScaling = true;
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- state.layerState.objects.forEach((object) => {
- if (object.kind === 'image') {
- if (object.image.image_name === image_name) {
- object.image.image_url = image_url;
- object.image.thumbnail_url = thumbnail_url;
- }
- }
- });
+ // state.layerState.objects.forEach((object) => {
+ // if (object.kind === 'image') {
+ // if (object.image.image_name === image_name) {
+ // object.image.image_url = image_url;
+ // object.image.thumbnail_url = thumbnail_url;
+ // }
+ // }
+ // });
- state.layerState.stagingArea.images.forEach((stagedImage) => {
- if (stagedImage.image.image_name === image_name) {
- stagedImage.image.image_url = image_url;
- stagedImage.image.thumbnail_url = thumbnail_url;
- }
- });
- });
+ // state.layerState.stagingArea.images.forEach((stagedImage) => {
+ // if (stagedImage.image.image_name === image_name) {
+ // stagedImage.image.image_url = image_url;
+ // stagedImage.image.thumbnail_url = thumbnail_url;
+ // }
+ // });
+ // });
},
});
diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
index ae78287a7b..9294e10d32 100644
--- a/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
+++ b/invokeai/frontend/web/src/features/canvas/store/canvasTypes.ts
@@ -38,7 +38,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
- image: ImageDTO;
+ imageName: string;
};
export type CanvasMaskLine = {
diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
index b8d8896dad..217caf9461 100644
--- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx
@@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnimatePresence, motion } from 'framer-motion';
-import { IAIImageFallback } from 'common/components/IAIImageFallback';
+import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
controlNetSelector,
@@ -31,24 +33,45 @@ type Props = {
const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props;
- const { controlNetId, controlImage, processedControlImage, processorType } =
- props.controlNet;
+ const {
+ controlNetId,
+ controlImage: controlImageName,
+ processedControlImage: processedControlImageName,
+ processorType,
+ } = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
+ const {
+ data: controlImage,
+ isLoading: isLoadingControlImage,
+ isError: isErrorControlImage,
+ isSuccess: isSuccessControlImage,
+ } = useGetImageDTOQuery(controlImageName ?? skipToken);
+
+ const {
+ data: processedControlImage,
+ isLoading: isLoadingProcessedControlImage,
+ isError: isErrorProcessedControlImage,
+ isSuccess: isSuccessProcessedControlImage,
+ } = useGetImageDTOQuery(processedControlImageName ?? skipToken);
+
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
- if (controlImage?.image_name === droppedImage.image_name) {
+ if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
- controlNetImageChanged({ controlNetId, controlImage: droppedImage })
+ controlNetImageChanged({
+ controlNetId,
+ controlImage: droppedImage.image_name,
+ })
);
},
- [controlImage, controlNetId, dispatch]
+ [controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => {
@@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
h: 'full',
}}
>
-
+
)}
{controlImage && (
diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
index f1b62cd997..5a54bdcd74 100644
--- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
+++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
@@ -39,8 +39,8 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
- controlImage: ImageDTO | null;
- processedControlImage: ImageDTO | null;
+ controlImage: string | null;
+ processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
@@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
},
controlNetAddedFromImage: (
state,
- action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
+ action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
@@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
- controlImage: ImageDTO | null;
+ controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
@@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
- processedControlImage: ImageDTO | null;
+ processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
@@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
// Preemptively remove the image from the gallery
const { imageName } = action.meta.arg;
forEach(state.controlNets, (c) => {
- if (c.controlImage?.image_name === imageName) {
+ if (c.controlImage === imageName) {
c.controlImage = null;
c.processedControlImage = null;
}
- if (c.processedControlImage?.image_name === imageName) {
+ if (c.processedControlImage === imageName) {
c.processedControlImage = null;
}
});
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- forEach(state.controlNets, (c) => {
- if (c.controlImage?.image_name === image_name) {
- c.controlImage.image_url = image_url;
- c.controlImage.thumbnail_url = thumbnail_url;
- }
- if (c.processedControlImage?.image_name === image_name) {
- c.processedControlImage.image_url = image_url;
- c.processedControlImage.thumbnail_url = thumbnail_url;
- }
- });
- });
+ // forEach(state.controlNets, (c) => {
+ // if (c.controlImage?.image_name === image_name) {
+ // c.controlImage.image_url = image_url;
+ // c.controlImage.thumbnail_url = thumbnail_url;
+ // }
+ // if (c.processedControlImage?.image_name === image_name) {
+ // c.processedControlImage.image_url = image_url;
+ // c.processedControlImage.thumbnail_url = thumbnail_url;
+ // }
+ // });
+ // });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx
new file mode 100644
index 0000000000..632cebcb33
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AddBoardButton.tsx
@@ -0,0 +1,27 @@
+import IAIButton from 'common/components/IAIButton';
+import { useCallback } from 'react';
+import { useCreateBoardMutation } from 'services/apiSlice';
+
+const DEFAULT_BOARD_NAME = 'My Board';
+
+const AddBoardButton = () => {
+ const [createBoard, { isLoading }] = useCreateBoardMutation();
+
+ const handleCreateBoard = useCallback(() => {
+ createBoard(DEFAULT_BOARD_NAME);
+ }, [createBoard]);
+
+ return (
+
+ Add Board
+
+ );
+};
+
+export default AddBoardButton;
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx
new file mode 100644
index 0000000000..e506c88e2d
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/AllImagesBoard.tsx
@@ -0,0 +1,93 @@
+import { Flex, Text } from '@chakra-ui/react';
+import { FaImages } from 'react-icons/fa';
+import { boardIdSelected } from '../../store/boardSlice';
+import { useDispatch } from 'react-redux';
+import { IAINoImageFallback } from 'common/components/IAIImageFallback';
+import { AnimatePresence } from 'framer-motion';
+import { SelectedItemOverlay } from '../SelectedItemOverlay';
+import { useCallback } from 'react';
+import { ImageDTO } from 'services/api';
+import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
+import { useDroppable } from '@dnd-kit/core';
+import IAIDropOverlay from 'common/components/IAIDropOverlay';
+
+const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
+ const dispatch = useDispatch();
+
+ const handleAllImagesBoardClick = () => {
+ dispatch(boardIdSelected());
+ };
+
+ const [removeImageFromBoard, { isLoading }] =
+ useRemoveImageFromBoardMutation();
+
+ const handleDrop = useCallback(
+ (droppedImage: ImageDTO) => {
+ if (!droppedImage.board_id) {
+ return;
+ }
+ removeImageFromBoard({
+ board_id: droppedImage.board_id,
+ image_name: droppedImage.image_name,
+ });
+ },
+ [removeImageFromBoard]
+ );
+
+ const {
+ isOver,
+ setNodeRef,
+ active: isDropActive,
+ } = useDroppable({
+ id: `board_droppable_all_images`,
+ data: {
+ handleDrop,
+ },
+ });
+
+ return (
+
+
+
+
+ {isSelected && }
+
+
+ {isDropActive && }
+
+
+
+ All Images
+
+
+ );
+};
+
+export default AllImagesBoard;
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx
new file mode 100644
index 0000000000..738693a278
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList.tsx
@@ -0,0 +1,134 @@
+import {
+ Collapse,
+ Flex,
+ Grid,
+ IconButton,
+ Input,
+ InputGroup,
+ InputRightElement,
+} from '@chakra-ui/react';
+import { createSelector } from '@reduxjs/toolkit';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import {
+ boardsSelector,
+ setBoardSearchText,
+} from 'features/gallery/store/boardSlice';
+import { memo, useState } from 'react';
+import HoverableBoard from './HoverableBoard';
+import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
+import AddBoardButton from './AddBoardButton';
+import AllImagesBoard from './AllImagesBoard';
+import { CloseIcon } from '@chakra-ui/icons';
+import { useListAllBoardsQuery } from 'services/apiSlice';
+
+const selector = createSelector(
+ [boardsSelector],
+ (boardsState) => {
+ const { selectedBoardId, searchText } = boardsState;
+ return { selectedBoardId, searchText };
+ },
+ defaultSelectorOptions
+);
+
+type Props = {
+ isOpen: boolean;
+};
+
+const BoardsList = (props: Props) => {
+ const { isOpen } = props;
+ const dispatch = useAppDispatch();
+ const { selectedBoardId, searchText } = useAppSelector(selector);
+
+ const { data: boards } = useListAllBoardsQuery();
+
+ const filteredBoards = searchText
+ ? boards?.filter((board) =>
+ board.board_name.toLowerCase().includes(searchText.toLowerCase())
+ )
+ : boards;
+
+ const [searchMode, setSearchMode] = useState(false);
+
+ const handleBoardSearch = (searchTerm: string) => {
+ setSearchMode(searchTerm.length > 0);
+ dispatch(setBoardSearchText(searchTerm));
+ };
+ const clearBoardSearch = () => {
+ setSearchMode(false);
+ dispatch(setBoardSearchText(''));
+ };
+
+ return (
+
+
+
+
+ {
+ handleBoardSearch(e.target.value);
+ }}
+ />
+ {searchText && searchText.length && (
+
+ }
+ />
+
+ )}
+
+
+
+
+
+ {!searchMode && }
+ {filteredBoards &&
+ filteredBoards.map((board) => (
+
+ ))}
+
+
+
+
+ );
+};
+
+export default memo(BoardsList);
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx
new file mode 100644
index 0000000000..a2c07e4870
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/HoverableBoard.tsx
@@ -0,0 +1,193 @@
+import {
+ Badge,
+ Box,
+ Editable,
+ EditableInput,
+ EditablePreview,
+ Flex,
+ Image,
+ MenuItem,
+ MenuList,
+} from '@chakra-ui/react';
+
+import { useAppDispatch } from 'app/store/storeHooks';
+import { memo, useCallback } from 'react';
+import { FaFolder, FaTrash } from 'react-icons/fa';
+import { ContextMenu } from 'chakra-ui-contextmenu';
+import { BoardDTO, ImageDTO } from 'services/api';
+import { IAINoImageFallback } from 'common/components/IAIImageFallback';
+import { boardIdSelected } from 'features/gallery/store/boardSlice';
+import {
+ useAddImageToBoardMutation,
+ useDeleteBoardMutation,
+ useGetImageDTOQuery,
+ useUpdateBoardMutation,
+} from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
+import { useDroppable } from '@dnd-kit/core';
+import { AnimatePresence } from 'framer-motion';
+import IAIDropOverlay from 'common/components/IAIDropOverlay';
+import { SelectedItemOverlay } from '../SelectedItemOverlay';
+
+interface HoverableBoardProps {
+ board: BoardDTO;
+ isSelected: boolean;
+}
+
+const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
+ const dispatch = useAppDispatch();
+
+ const { data: coverImage } = useGetImageDTOQuery(
+ board.cover_image_name ?? skipToken
+ );
+
+ const { board_name, board_id } = board;
+
+ const handleSelectBoard = useCallback(() => {
+ dispatch(boardIdSelected(board_id));
+ }, [board_id, dispatch]);
+
+ const [updateBoard, { isLoading: isUpdateBoardLoading }] =
+ useUpdateBoardMutation();
+
+ const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
+ useDeleteBoardMutation();
+
+ const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
+ useAddImageToBoardMutation();
+
+ const handleUpdateBoardName = (newBoardName: string) => {
+ updateBoard({ board_id, changes: { board_name: newBoardName } });
+ };
+
+ const handleDeleteBoard = useCallback(() => {
+ deleteBoard(board_id);
+ }, [board_id, deleteBoard]);
+
+ const handleDrop = useCallback(
+ (droppedImage: ImageDTO) => {
+ if (droppedImage.board_id === board_id) {
+ return;
+ }
+ addImageToBoard({ board_id, image_name: droppedImage.image_name });
+ },
+ [addImageToBoard, board_id]
+ );
+
+ const {
+ isOver,
+ setNodeRef,
+ active: isDropActive,
+ } = useDroppable({
+ id: `board_droppable_${board_id}`,
+ data: {
+ handleDrop,
+ },
+ });
+
+ return (
+
+
+ menuProps={{ size: 'sm', isLazy: true }}
+ renderMenu={() => (
+
+ }
+ onClickCapture={handleDeleteBoard}
+ >
+ Delete Board
+
+
+ )}
+ >
+ {(ref) => (
+
+
+ {board.cover_image_name && coverImage?.image_url && (
+
+ )}
+ {!(board.cover_image_name && coverImage?.image_url) && (
+
+ )}
+
+ {board.image_count}
+
+
+ {isSelected && }
+
+
+ {isDropActive && }
+
+
+
+
+ {
+ handleUpdateBoardName(nextValue);
+ }}
+ >
+
+
+
+
+
+ )}
+
+
+ );
+});
+
+HoverableBoard.displayName = 'HoverableBoard';
+
+export default HoverableBoard;
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx
new file mode 100644
index 0000000000..b16bddd6b4
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/UpdateImageBoardModal.tsx
@@ -0,0 +1,93 @@
+import {
+ AlertDialog,
+ AlertDialogBody,
+ AlertDialogContent,
+ AlertDialogFooter,
+ AlertDialogHeader,
+ AlertDialogOverlay,
+ Box,
+ Flex,
+ Spinner,
+ Text,
+} from '@chakra-ui/react';
+import IAIButton from 'common/components/IAIButton';
+
+import { memo, useContext, useRef, useState } from 'react';
+import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import { useListAllBoardsQuery } from 'services/apiSlice';
+
+const UpdateImageBoardModal = () => {
+ // const boards = useSelector(selectBoardsAll);
+ const { data: boards, isFetching } = useListAllBoardsQuery();
+ const { isOpen, onClose, handleAddToBoard, image } = useContext(
+ AddImageToBoardContext
+ );
+ const [selectedBoard, setSelectedBoard] = useState();
+
+ const cancelRef = useRef(null);
+
+ const currentBoard = boards?.find(
+ (board) => board.board_id === image?.board_id
+ );
+
+ return (
+
+
+
+
+ {currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
+
+
+
+
+
+ {currentBoard && (
+
+ Moving this image from{' '}
+ {currentBoard.board_name} to
+
+ )}
+ {isFetching ? (
+
+ ) : (
+ setSelectedBoard(v)}
+ value={selectedBoard}
+ data={(boards ?? []).map((board) => ({
+ label: board.board_name,
+ value: board.board_id,
+ }))}
+ />
+ )}
+
+
+
+
+ Cancel
+ {
+ if (selectedBoard) {
+ handleAddToBoard(selectedBoard);
+ }
+ }}
+ ml={3}
+ >
+ {currentBoard ? 'Move' : 'Add'}
+
+
+
+
+
+ );
+};
+
+export default memo(UpdateImageBoardModal);
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
index a5eaeb4c71..169a965be0 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImageButtons.tsx
@@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { DeleteImageButton } from './DeleteImageModal';
+import { selectImagesById } from '../store/imagesSlice';
+import { RootState } from 'app/store/store';
const currentImageButtonsSelector = createSelector(
[
+ (state: RootState) => state,
systemSelector,
gallerySelector,
postprocessingSelector,
@@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
lightboxSelector,
activeTabNameSelector,
],
- (system, gallery, postprocessing, ui, lightbox, activeTabName) => {
+ (state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
const {
isProcessing,
isConnected,
@@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
shouldShowProgressInViewer,
} = ui;
+ const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
+
const { selectedImage } = gallery;
return {
@@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
activeTabName,
isLightboxOpen,
shouldHidePreview,
- image: selectedImage,
- seed: selectedImage?.metadata?.seed,
- prompt: selectedImage?.metadata?.positive_conditioning,
- negativePrompt: selectedImage?.metadata?.negative_conditioning,
+ image: imageDTO,
+ seed: imageDTO?.metadata?.seed,
+ prompt: imageDTO?.metadata?.positive_conditioning,
+ negativePrompt: imageDTO?.metadata?.negative_conditioning,
shouldShowProgressInViewer,
};
},
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
index c591206a27..5426fee3b1 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx
@@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
-import { configSelector } from '../../system/store/configSelectors';
-import { useAppToaster } from 'app/components/Toaster';
import { imageSelected } from '../store/gallerySlice';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
-import { IAIImageFallback } from 'common/components/IAIImageFallback';
+import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector],
@@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
return {
shouldShowImageDetails,
shouldHidePreview,
- image: selectedImage,
+ selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
@@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
const CurrentImagePreview = () => {
const {
shouldShowImageDetails,
- image,
+ selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
+
+ // const image = useAppSelector((state: RootState) =>
+ // selectImagesById(state, selectedImage ?? '')
+ // );
+
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(selectedImage ?? skipToken);
+
const dispatch = useAppDispatch();
const handleDrop = useCallback(
@@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
if (droppedImage.image_name === image?.image_name) {
return;
}
- dispatch(imageSelected(droppedImage));
+ dispatch(imageSelected(droppedImage.image_name));
},
[dispatch, image?.image_name]
);
@@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
}}
>
}
+ fallback={}
isUploadDisabled={true}
/>
)}
- {shouldShowImageDetails && image && (
+ {shouldShowImageDetails && image && selectedImage && (
{
)}
- {!shouldShowImageDetails && image && (
+ {!shouldShowImageDetails && image && selectedImage && (
- prev.image.image_name === next.image.image_name &&
- prev.isSelected === next.isSelected;
-
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
-const HoverableImage = memo((props: HoverableImageProps) => {
+const HoverableImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const {
activeTabName,
@@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onDelete } = useContext(DeleteImageContext);
+ const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
onDelete(image);
}, [image, onDelete]);
@@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
},
});
+ const [removeFromBoard] = useRemoveImageFromBoardMutation();
+
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleSelectImage = useCallback(() => {
- dispatch(imageSelected(image));
+ dispatch(imageSelected(image.image_name));
}, [image, dispatch]);
// Recall parameters handlers
@@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// dispatch(setIsLightboxOpen(true));
};
+ const handleAddToBoard = useCallback(() => {
+ onClickAddToBoard(image);
+ }, [image, onClickAddToBoard]);
+
+ const handleRemoveFromBoard = useCallback(() => {
+ if (!image.board_id) {
+ return;
+ }
+ removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
+ }, [image.board_id, image.image_name, removeFromBoard]);
+
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
@@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
)}
+ } onClickCapture={handleAddToBoard}>
+ {image.board_id ? 'Change Board' : 'Add to Board'}
+
+ {image.board_id && (
+ }
+ onClickCapture={handleRemoveFromBoard}
+ >
+ Remove from Board
+
+ )}
}
@@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
);
-}, memoEqualityCheck);
+};
-HoverableImage.displayName = 'HoverableImage';
-
-export default HoverableImage;
+export default memo(HoverableImage);
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
index fe8690e379..46f2378ae0 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx
@@ -1,12 +1,15 @@
import {
Box,
+ Button,
ButtonGroup,
Flex,
FlexProps,
Grid,
Icon,
Text,
+ VStack,
forwardRef,
+ useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@@ -20,6 +23,7 @@ import {
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
+ setGalleryView,
} from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
@@ -53,41 +57,51 @@ import {
selectImagesAll,
} from '../store/imagesSlice';
import { receivedPageOfImages } from 'services/thunks/image';
+import BoardsList from './Boards/BoardsList';
+import { boardsSelector } from '../store/boardSlice';
+import { ChevronUpIcon } from '@chakra-ui/icons';
+import { useListAllBoardsQuery } from 'services/apiSlice';
-const categorySelector = createSelector(
+const itemSelector = createSelector(
[(state: RootState) => state],
(state) => {
- const { images } = state;
- const { categories } = images;
+ const { categories, total: allImagesTotal, isLoading } = state.images;
+ const { selectedBoardId } = state.boards;
const allImages = selectImagesAll(state);
- const filteredImages = allImages.filter((i) =>
- categories.includes(i.image_category)
- );
+
+ const images = allImages.filter((i) => {
+ const isInCategory = categories.includes(i.image_category);
+ const isInSelectedBoard = selectedBoardId
+ ? i.board_id === selectedBoardId
+ : true;
+ return isInCategory && isInSelectedBoard;
+ });
return {
- images: filteredImages,
- isLoading: images.isLoading,
- areMoreImagesAvailable: filteredImages.length < images.total,
- categories: images.categories,
+ images,
+ allImagesTotal,
+ isLoading,
+ categories,
+ selectedBoardId,
};
},
defaultSelectorOptions
);
const mainSelector = createSelector(
- [gallerySelector, uiSelector],
- (gallery, ui) => {
+ [gallerySelector, uiSelector, boardsSelector],
+ (gallery, ui, boards) => {
const {
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
+ galleryView,
} = gallery;
const { shouldPinGallery } = ui;
-
return {
shouldPinGallery,
galleryImageMinimumWidth,
@@ -95,6 +109,8 @@ const mainSelector = createSelector(
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
+ galleryView,
+ selectedBoardId: boards.selectedBoardId,
};
},
defaultSelectorOptions
@@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
+ galleryView,
} = useAppSelector(mainSelector);
- const { images, areMoreImagesAvailable, isLoading, categories } =
- useAppSelector(categorySelector);
+ const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
+ useAppSelector(itemSelector);
+
+ const { selectedBoard } = useListAllBoardsQuery(undefined, {
+ selectFromResult: ({ data }) => ({
+ selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
+ }),
+ });
+
+ const filteredImagesTotal = useMemo(
+ () => selectedBoard?.image_count ?? allImagesTotal,
+ [allImagesTotal, selectedBoard?.image_count]
+ );
+
+ const areMoreAvailable = useMemo(() => {
+ return images.length < filteredImagesTotal;
+ }, [images.length, filteredImagesTotal]);
const handleLoadMoreImages = useCallback(() => {
- dispatch(receivedPageOfImages());
- }, [dispatch]);
+ dispatch(
+ receivedPageOfImages({
+ categories,
+ boardId: selectedBoardId,
+ })
+ );
+ }, [categories, dispatch, selectedBoardId]);
const handleEndReached = useMemo(() => {
- if (areMoreImagesAvailable && !isLoading) {
+ if (areMoreAvailable && !isLoading) {
return handleLoadMoreImages;
}
return undefined;
- }, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
+ }, [areMoreAvailable, handleLoadMoreImages, isLoading]);
+
+ const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
@@ -172,46 +211,79 @@ const ImageGalleryContent = () => {
const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
+ dispatch(setGalleryView('images'));
}, [dispatch]);
const handleClickAssetsCategory = useCallback(() => {
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
+ dispatch(setGalleryView('assets'));
}, [dispatch]);
return (
-
-
-
-
+
+
+ }
+ />
+ }
+ />
+
+ }
- />
- }
- />
-
-
+ variant="ghost"
+ sx={{
+ w: 'full',
+ justifyContent: 'center',
+ alignItems: 'center',
+ px: 2,
+ _hover: {
+ bg: 'base.800',
+ },
+ }}
+ >
+
+ {selectedBoard ? selectedBoard.board_name : 'All Images'}
+
+
+
{
icon={shouldPinGallery ? : }
/>
-
-
- {images.length || areMoreImagesAvailable ? (
+
+
+
+
+
+ {images.length || areMoreAvailable ? (
<>
{shouldUseSingleGalleryColumn ? (
@@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
- itemContent={(index, image) => (
+ itemContent={(index, item) => (
)}
@@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
List: ListContainer,
}}
scrollerRef={setScroller}
- itemContent={(index, image) => (
+ itemContent={(index, item) => (
)}
/>
@@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
- {areMoreImagesAvailable
+ {areMoreAvailable
? t('gallery.loadMore')
: t('gallery.allImagesLoaded')}
@@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
)}
-
+
);
};
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx
index 892516a3cc..e5cb4cf4a8 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetaDataViewer/ImageMetadataViewer.tsx
@@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
image: ImageDTO;
};
-// TODO: I don't know if this is needed.
-const memoEqualityCheck = (
- prev: ImageMetadataViewerProps,
- next: ImageMetadataViewerProps
-) => prev.image.image_name === next.image.image_name;
-
-// TODO: Show more interesting information in this component.
-
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
-const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
+const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const {
recallBothPrompts,
@@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
);
-}, memoEqualityCheck);
+};
-ImageMetadataViewer.displayName = 'ImageMetadataViewer';
-
-export default ImageMetadataViewer;
+export default memo(ImageMetadataViewer);
diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
index 82e7a0d623..b1f06ad433 100644
--- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx
@@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
}
const currentImageIndex = filteredImageIds.findIndex(
- (i) => i === selectedImage.image_name
+ (i) => i === selectedImage
);
const nextImageIndex = clamp(
@@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
nextImage,
prevImage,
+ nextImageId,
+ prevImageId,
};
},
{
@@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const { isOnFirstImage, isOnLastImage, nextImage, prevImage } =
+ const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
useAppSelector(nextPrevImageButtonsSelector);
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
@@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
}, []);
const handlePrevImage = useCallback(() => {
- dispatch(imageSelected(prevImage));
- }, [dispatch, prevImage]);
+ dispatch(imageSelected(prevImageId));
+ }, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
- dispatch(imageSelected(nextImage));
- }, [dispatch, nextImage]);
+ dispatch(imageSelected(nextImageId));
+ }, [dispatch, nextImageId]);
useHotkeys(
'left',
() => {
handlePrevImage();
},
- [prevImage]
+ [prevImageId]
);
useHotkeys(
@@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
() => {
handleNextImage();
},
- [nextImage]
+ [nextImageId]
);
return (
diff --git a/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx
new file mode 100644
index 0000000000..7038b4b64f
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/components/SelectedItemOverlay.tsx
@@ -0,0 +1,26 @@
+import { motion } from 'framer-motion';
+
+export const SelectedItemOverlay = () => (
+
+);
diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts
new file mode 100644
index 0000000000..3dac2b6e50
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/store/boardSelectors.ts
@@ -0,0 +1,23 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { RootState } from 'app/store/store';
+import { selectBoardsAll } from './boardSlice';
+
+export const boardSelector = (state: RootState) => state.boards.entities;
+
+export const searchBoardsSelector = createSelector(
+ (state: RootState) => state,
+ (state) => {
+ const {
+ boards: { searchText },
+ } = state;
+
+ if (!searchText) {
+ // If no search text provided, return all entities
+ return selectBoardsAll(state);
+ }
+
+ return selectBoardsAll(state).filter((i) =>
+ i.board_name.toLowerCase().includes(searchText.toLowerCase())
+ );
+ }
+);
diff --git a/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts
new file mode 100644
index 0000000000..8fc9bfa486
--- /dev/null
+++ b/invokeai/frontend/web/src/features/gallery/store/boardSlice.ts
@@ -0,0 +1,47 @@
+import { PayloadAction, createSlice } from '@reduxjs/toolkit';
+import { RootState } from 'app/store/store';
+import { api } from 'services/apiSlice';
+
+type BoardsState = {
+ searchText: string;
+ selectedBoardId?: string;
+ updateBoardModalOpen: boolean;
+};
+
+export const initialBoardsState: BoardsState = {
+ updateBoardModalOpen: false,
+ searchText: '',
+};
+
+const boardsSlice = createSlice({
+ name: 'boards',
+ initialState: initialBoardsState,
+ reducers: {
+ boardIdSelected: (state, action: PayloadAction) => {
+ state.selectedBoardId = action.payload;
+ },
+ setBoardSearchText: (state, action: PayloadAction) => {
+ state.searchText = action.payload;
+ },
+ setUpdateBoardModalOpen: (state, action: PayloadAction) => {
+ state.updateBoardModalOpen = action.payload;
+ },
+ },
+ extraReducers: (builder) => {
+ builder.addMatcher(
+ api.endpoints.deleteBoard.matchFulfilled,
+ (state, action) => {
+ if (action.meta.arg.originalArgs === state.selectedBoardId) {
+ state.selectedBoardId = undefined;
+ }
+ }
+ );
+ },
+});
+
+export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
+ boardsSlice.actions;
+
+export const boardsSelector = (state: RootState) => state.boards;
+
+export default boardsSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
index 4f250a7c3a..b7fc0809a6 100644
--- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
+++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts
@@ -1,17 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
-import { ImageDTO } from 'services/api';
import { imageUpserted } from './imagesSlice';
-import { imageUrlsReceived } from 'services/thunks/image';
type GalleryImageObjectFitType = 'contain' | 'cover';
export interface GalleryState {
- selectedImage?: ImageDTO;
+ selectedImage?: string;
galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean;
+ galleryView: 'images' | 'assets' | 'boards';
}
export const initialGalleryState: GalleryState = {
@@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false,
+ galleryView: 'images',
};
export const gallerySlice = createSlice({
name: 'gallery',
initialState: initialGalleryState,
reducers: {
- imageSelected: (state, action: PayloadAction) => {
+ imageSelected: (state, action: PayloadAction) => {
state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false;
@@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
) => {
state.shouldUseSingleGalleryColumn = action.payload;
},
+ setGalleryView: (
+ state,
+ action: PayloadAction<'images' | 'assets' | 'boards'>
+ ) => {
+ state.galleryView = action.payload;
+ },
},
extraReducers: (builder) => {
builder.addCase(imageUpserted, (state, action) => {
@@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
state.shouldAutoSwitchToNewImages &&
action.payload.image_category === 'general'
) {
- state.selectedImage = action.payload;
+ state.selectedImage = action.payload.image_name;
}
});
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
+ // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
+ // const { image_name, image_url, thumbnail_url } = action.payload;
- if (state.selectedImage?.image_name === image_name) {
- state.selectedImage.image_url = image_url;
- state.selectedImage.thumbnail_url = thumbnail_url;
- }
- });
+ // if (state.selectedImage?.image_name === image_name) {
+ // state.selectedImage.image_url = image_url;
+ // state.selectedImage.thumbnail_url = thumbnail_url;
+ // }
+ // });
},
});
@@ -75,6 +81,7 @@ export const {
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
+ setGalleryView,
} = gallerySlice.actions;
export default gallerySlice.reducer;
diff --git a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
index 9c18380c54..25a3341532 100644
--- a/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
+++ b/invokeai/frontend/web/src/features/gallery/store/imagesSlice.ts
@@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
import { keyBy } from 'lodash-es';
import {
imageDeleted,
- imageMetadataReceived,
imageUrlsReceived,
receivedPageOfImages,
} from 'services/thunks/image';
@@ -74,11 +73,21 @@ const imagesSlice = createSlice({
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
+ const { boardId, categories, imageOrigin, isIntermediate } =
+ action.meta.arg;
+
const { items, offset, limit, total } = action.payload;
+ imagesAdapter.upsertMany(state, items);
+
+ if (!categories?.includes('general') || boardId) {
+ // need to skip updating the total images count if the images recieved were for a specific board
+ // TODO: this doesn't work when on the Asset tab/category...
+ return;
+ }
+
state.offset = offset;
state.limit = limit;
state.total = total;
- imagesAdapter.upsertMany(state, items);
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Image deleted
@@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
.map((i) => i.image_name);
}
);
+
+// export const selectImageById = createSelector(
+// (state: RootState, imageId) => state,
+// (state) => {
+// const {
+// images: { categories },
+// } = state;
+
+// return selectImagesAll(state)
+// .filter((i) => categories.includes(i.image_category))
+// .map((i) => i.image_name);
+// }
+// );
diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
index dc4590e6ca..c5a3a1970b 100644
--- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx
@@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { Flex } from '@chakra-ui/react';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const ImageInputFieldComponent = (
props: FieldComponentProps
@@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
const dispatch = useAppDispatch();
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(field.value ?? skipToken);
+
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
- if (field.value?.image_name === droppedImage.image_name) {
+ if (field.value === droppedImage.image_name) {
return;
}
@@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
fieldValueChanged({
nodeId,
fieldName: field.name,
- value: droppedImage,
+ value: droppedImage.image_name,
})
);
},
- [dispatch, field.name, field.value?.image_name, nodeId]
+ [dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => {
@@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
}}
>
{
- return { allModelNames };
- // return map(modelList, (_, name) => name);
- },
- {
- memoizeOptions: {
- resultEqualityCheck: isEqual,
- },
- }
-);
+import { memo, useCallback, useEffect, useMemo } from 'react';
+import { FieldComponentProps } from './types';
+import { forEach, isString } from 'lodash-es';
+import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import { useTranslation } from 'react-i18next';
+import { useListModelsQuery } from 'services/apiSlice';
const ModelInputFieldComponent = (
props: FieldComponentProps
@@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
const { nodeId, field } = props;
const dispatch = useAppDispatch();
+ const { t } = useTranslation();
- const { allModelNames } = useAppSelector(availableModelsSelector);
+ const { data: pipelineModels } = useListModelsQuery({
+ model_type: 'pipeline',
+ });
- const handleValueChanged = (e: ChangeEvent) => {
- dispatch(
- fieldValueChanged({
- nodeId,
- fieldName: field.name,
- value: e.target.value,
- })
- );
- };
+ const data = useMemo(() => {
+ if (!pipelineModels) {
+ return [];
+ }
+
+ const data: SelectItem[] = [];
+
+ forEach(pipelineModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ data.push({
+ value: id,
+ label: model.name,
+ group: BASE_MODEL_NAME_MAP[model.base_model],
+ });
+ });
+
+ return data;
+ }, [pipelineModels]);
+
+ const selectedModel = useMemo(
+ () => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
+ [pipelineModels?.entities, pipelineModels?.ids, field.value]
+ );
+
+ const handleValueChanged = useCallback(
+ (v: string | null) => {
+ if (!v) {
+ return;
+ }
+
+ dispatch(
+ fieldValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: v,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ useEffect(() => {
+ if (field.value && pipelineModels?.ids.includes(field.value)) {
+ return;
+ }
+
+ const firstModel = pipelineModels?.ids[0];
+
+ if (!isString(firstModel)) {
+ return;
+ }
+
+ handleValueChanged(firstModel);
+ }, [field.value, handleValueChanged, pipelineModels?.ids]);
return (
-
+ />
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 5425d1cfd5..341f0c467b 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -101,21 +101,6 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
});
-
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
-
- state.nodes.forEach((node) => {
- forEach(node.data.inputs, (input) => {
- if (input.type === 'image') {
- if (input.value?.image_name === image_name) {
- input.value.image_url = image_url;
- input.value.thumbnail_url = thumbnail_url;
- }
- }
- });
- });
- });
},
});
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 5e140b6eef..acad10cf48 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
- value?: ImageDTO;
+ value?: string;
};
export type ModelInputFieldValue = FieldValueBase & {
diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts
index dd5a97e2f1..314af85193 100644
--- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts
@@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
- const { image_name } = processedControlImage;
controlNetNode.image = {
- image_name,
+ image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
- const { image_name } = controlImage;
controlNetNode.image = {
- image_name,
+ image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts
index efaeaddff2..ccdc3e0a27 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts
@@ -23,6 +23,7 @@ import {
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const {
positivePrompt,
negativePrompt,
- model: model_name,
+ model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
+ const model = modelIdToPipelineModelField(modelId);
+
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE,
},
[MODEL_LOADER]: {
- type: 'sd1_model_loader',
+ type: 'pipeline_model_loader',
id: MODEL_LOADER,
- model_name,
+ model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
index 785e1d2fdb..9ffe85b3c9 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts
@@ -17,6 +17,7 @@ import {
INPAINT_GRAPH,
INPAINT,
} from './constants';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
const {
positivePrompt,
negativePrompt,
- model: model_name,
+ model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
+ const model = modelIdToPipelineModelField(modelId);
+
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
nodes: {
@@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
prompt: negativePrompt,
},
[MODEL_LOADER]: {
- type: 'sd1_model_loader',
+ type: 'pipeline_model_loader',
id: MODEL_LOADER,
- model_name,
+ model,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts
index ca0e56e849..920cb5bf02 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts
@@ -14,6 +14,7 @@ import {
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/**
* Builds the Canvas tab's Text to Image graph.
@@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const {
positivePrompt,
negativePrompt,
- model: model_name,
+ model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
+ const model = modelIdToPipelineModelField(modelId);
+
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
steps,
},
[MODEL_LOADER]: {
- type: 'sd1_model_loader',
+ type: 'pipeline_model_loader',
id: MODEL_LOADER,
- model_name,
+ model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts
index 1f2c8327e0..8425ac043a 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts
@@ -22,6 +22,7 @@ import {
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const {
positivePrompt,
negativePrompt,
- model: model_name,
+ model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state');
}
+ const model = modelIdToPipelineModelField(modelId);
+
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH,
@@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
id: NOISE,
},
[MODEL_LOADER]: {
- type: 'sd1_model_loader',
+ type: 'pipeline_model_loader',
id: MODEL_LOADER,
- model_name,
+ model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
@@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = (
id: RESIZE,
type: 'img_resize',
image: {
- image_name: initialImage.image_name,
+ image_name: initialImage.imageName,
},
is_intermediate: true,
width,
@@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = (
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
- image_name: initialImage.image_name,
+ image_name: initialImage.imageName,
});
// Pass the image's dimensions to the `NOISE` node
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts
index c179a89504..973acdfb77 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts
@@ -1,6 +1,10 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
-import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
+import {
+ BaseModelType,
+ RandomIntInvocation,
+ RangeOfSizeInvocation,
+} from 'services/api';
import {
ITERATE,
LATENTS_TO_IMAGE,
@@ -14,6 +18,7 @@ import {
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = {
width: number;
@@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
const {
positivePrompt,
negativePrompt,
- model: model_name,
+ model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
shouldRandomizeSeed,
} = state.generation;
+ const model = modelIdToPipelineModelField(modelId);
+
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
steps,
},
[MODEL_LOADER]: {
- type: 'sd1_model_loader',
+ type: 'pipeline_model_loader',
id: MODEL_LOADER,
- model_name,
+ model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
index 6a700d4813..072b1a53fd 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts
@@ -1,9 +1,10 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
-import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
+import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
import { AnyInvocation } from 'services/events/types';
+import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/**
* We need to do special handling for some fields
@@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
}
}
+ if (field.type === 'model') {
+ if (field.value) {
+ return modelIdToPipelineModelField(field.value);
+ }
+ }
+
return field.value;
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts
index 39e0080d11..7d4469bc41 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts
@@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
-export const MODEL_LOADER = 'model_loader';
+export const MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';
diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts
new file mode 100644
index 0000000000..bbcd8d9bc6
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToPipelineModelField.ts
@@ -0,0 +1,18 @@
+import { BaseModelType, PipelineModelField } from 'services/api';
+
+/**
+ * Crudely converts a model id to a pipeline model field
+ * TODO: Make better
+ */
+export const modelIdToPipelineModelField = (
+ modelId: string
+): PipelineModelField => {
+ const [base_model, model_type, model_name] = modelId.split('/');
+
+ const field: PipelineModelField = {
+ base_model: base_model as BaseModelType,
+ model_name,
+ };
+
+ return field;
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
index e29b46af70..6ebd014876 100644
--- a/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/nodeBuilders/buildImageToImageNode.ts
@@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
- image_name: initialImage.image_name,
+ image_name: initialImage.imageName,
};
}
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx
index 65da89b94d..5092893eed 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSchedulerAndModel.tsx
@@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => {
return (
-
+
diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
index fa415074e6..fbfa00e2a1 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx
@@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
-import { IAIImageFallback } from 'common/components/IAIImageFallback';
+import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
+import { useGetImageDTOQuery } from 'services/apiSlice';
+import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
[generationSelector],
@@ -27,14 +29,21 @@ const InitialImagePreview = () => {
const { initialImage } = useAppSelector(selector);
const dispatch = useAppDispatch();
+ const {
+ data: image,
+ isLoading,
+ isError,
+ isSuccess,
+ } = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
+
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
- if (droppedImage.image_name === initialImage?.image_name) {
+ if (droppedImage.image_name === initialImage?.imageName) {
return;
}
dispatch(initialImageChanged(droppedImage));
},
- [dispatch, initialImage?.image_name]
+ [dispatch, initialImage]
);
const handleReset = useCallback(() => {
@@ -53,10 +62,10 @@ const InitialImagePreview = () => {
}}
>
}
+ fallback={}
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
withResetIcon
/>
diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
index 961ea1b8af..e7dcbf0d83 100644
--- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts
@@ -1,10 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
+import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
import { configChanged } from 'features/system/store/configSlice';
-import { clamp, sortBy } from 'lodash-es';
+import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api';
-import { imageUrlsReceived } from 'services/thunks/image';
-import { receivedModels } from 'services/thunks/model';
import {
CfgScaleParam,
HeightParam,
@@ -17,14 +16,13 @@ import {
StrengthParam,
WidthParam,
} from './parameterZodSchemas';
-import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState {
cfgScale: CfgScaleParam;
height: HeightParam;
img2imgStrength: StrengthParam;
infillMethod: string;
- initialImage?: ImageDTO;
+ initialImage?: { imageName: string; width: number; height: number };
iterations: number;
perlin: number;
positivePrompt: PositivePromptParam;
@@ -212,35 +210,20 @@ export const generationSlice = createSlice({
state.shouldUseNoiseSettings = action.payload;
},
initialImageChanged: (state, action: PayloadAction) => {
- state.initialImage = action.payload;
+ const { image_name, width, height } = action.payload;
+ state.initialImage = { imageName: image_name, width, height };
},
modelSelected: (state, action: PayloadAction) => {
state.model = action.payload;
},
},
extraReducers: (builder) => {
- builder.addCase(receivedModels.fulfilled, (state, action) => {
- if (!state.model) {
- const firstModel = sortBy(action.payload, 'name')[0];
- state.model = firstModel.name;
- }
- });
-
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
state.model = defaultModel;
}
});
-
- builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
- const { image_name, image_url, thumbnail_url } = action.payload;
-
- if (state.initialImage?.image_name === image_name) {
- state.initialImage.image_url = image_url;
- state.initialImage.thumbnail_url = thumbnail_url;
- }
- });
},
});
diff --git a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
index 61567d3fb8..48eb309e7d 100644
--- a/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
+++ b/invokeai/frontend/web/src/features/parameters/store/parameterZodSchemas.ts
@@ -154,3 +154,17 @@ export type StrengthParam = z.infer;
*/
export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success;
+
+// /**
+// * Zod schema for BaseModelType
+// */
+// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
+// /**
+// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
+// */
+// export type BaseModelType = z.infer;
+// /**
+// * Validates/type-guards a value as a base model type
+// */
+// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
+// zBaseModelType.safeParse(val).success;
diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
index a38ab150dd..43de14d507 100644
--- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
+++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx
@@ -1,44 +1,59 @@
-import { createSelector } from '@reduxjs/toolkit';
-import { isEqual } from 'lodash-es';
-import { memo, useCallback } from 'react';
+import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
-import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import IAIMantineSelect, {
- IAISelectDataType,
-} from 'common/components/IAIMantineSelect';
-import { generationSelector } from 'features/parameters/store/generationSelectors';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
-import { selectModelsAll, selectModelsById } from '../store/modelSlice';
-const selector = createSelector(
- [(state: RootState) => state, generationSelector],
- (state, generation) => {
- const selectedModel = selectModelsById(state, generation.model);
+import { forEach, isString } from 'lodash-es';
+import { SelectItem } from '@mantine/core';
+import { RootState } from 'app/store/store';
+import { useListModelsQuery } from 'services/apiSlice';
- const modelData = selectModelsAll(state)
- .map((m) => ({
- value: m.name,
- label: m.name,
- }))
- .sort((a, b) => a.label.localeCompare(b.label));
- return {
- selectedModel,
- modelData,
- };
- },
- {
- memoizeOptions: {
- resultEqualityCheck: isEqual,
- },
- }
-);
+export const MODEL_TYPE_MAP = {
+ 'sd-1': 'Stable Diffusion 1.x',
+ 'sd-2': 'Stable Diffusion 2.x',
+};
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
- const { selectedModel, modelData } = useAppSelector(selector);
+
+ const selectedModelId = useAppSelector(
+ (state: RootState) => state.generation.model
+ );
+
+ const { data: pipelineModels } = useListModelsQuery({
+ model_type: 'pipeline',
+ });
+
+ const data = useMemo(() => {
+ if (!pipelineModels) {
+ return [];
+ }
+
+ const data: SelectItem[] = [];
+
+ forEach(pipelineModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ data.push({
+ value: id,
+ label: model.name,
+ group: MODEL_TYPE_MAP[model.base_model],
+ });
+ });
+
+ return data;
+ }, [pipelineModels]);
+
+ const selectedModel = useMemo(
+ () => pipelineModels?.entities[selectedModelId],
+ [pipelineModels?.entities, selectedModelId]
+ );
+
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
@@ -49,13 +64,27 @@ const ModelSelect = () => {
[dispatch]
);
+ useEffect(() => {
+ if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
+ return;
+ }
+
+ const firstModel = pipelineModels?.ids[0];
+
+ if (!isString(firstModel)) {
+ return;
+ }
+
+ handleChangeModel(firstModel);
+ }, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
+
return (
);
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
index 2e0b3234c7..26c11604e1 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx
@@ -1,6 +1,5 @@
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store';
-
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
@@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
export default function SettingsSchedulers() {
const dispatch = useAppDispatch();
+
const { t } = useTranslation();
const enabledSchedulers = useAppSelector(
diff --git a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts
index 193420e29c..8ba5731a5b 100644
--- a/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts
+++ b/invokeai/frontend/web/src/features/system/hooks/useIsApplicationReady.ts
@@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
const isApplicationReadySelector = createSelector(
[systemSelector, configSelector],
(system, config) => {
- const { wereModelsReceived, wasSchemaParsed } = system;
+ const { wasSchemaParsed } = system;
const { disabledTabs } = config;
return {
disabledTabs,
- wereModelsReceived,
wasSchemaParsed,
};
}
@@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
*/
export const useIsApplicationReady = () => {
- const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
+ const { disabledTabs, wasSchemaParsed } = useAppSelector(
isApplicationReadySelector
);
const isApplicationReady = useMemo(() => {
- if (!wereModelsReceived) {
- return false;
- }
-
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
return false;
}
return true;
- }, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
+ }, [disabledTabs, wasSchemaParsed]);
return isApplicationReady;
};
diff --git a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts b/invokeai/frontend/web/src/features/system/store/modelSelectors.ts
deleted file mode 100644
index f857bc85bc..0000000000
--- a/invokeai/frontend/web/src/features/system/store/modelSelectors.ts
+++ /dev/null
@@ -1,3 +0,0 @@
-import { RootState } from 'app/store/store';
-
-export const modelSelector = (state: RootState) => state.models;
diff --git a/invokeai/frontend/web/src/features/system/store/modelSlice.ts b/invokeai/frontend/web/src/features/system/store/modelSlice.ts
deleted file mode 100644
index ed38425872..0000000000
--- a/invokeai/frontend/web/src/features/system/store/modelSlice.ts
+++ /dev/null
@@ -1,47 +0,0 @@
-import { createEntityAdapter } from '@reduxjs/toolkit';
-import { createSlice } from '@reduxjs/toolkit';
-import { RootState } from 'app/store/store';
-import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
-import { receivedModels } from 'services/thunks/model';
-
-export type Model = (CkptModelInfo | DiffusersModelInfo) & {
- name: string;
-};
-
-export const modelsAdapter = createEntityAdapter({
- selectId: (model) => model.name,
- sortComparer: (a, b) => a.name.localeCompare(b.name),
-});
-
-export const initialModelsState = modelsAdapter.getInitialState();
-
-export type ModelsState = typeof initialModelsState;
-
-export const modelsSlice = createSlice({
- name: 'models',
- initialState: initialModelsState,
- reducers: {
- modelAdded: modelsAdapter.upsertOne,
- },
- extraReducers(builder) {
- /**
- * Received Models - FULFILLED
- */
- builder.addCase(receivedModels.fulfilled, (state, action) => {
- const models = action.payload;
- modelsAdapter.setAll(state, models);
- });
- },
-});
-
-export const {
- selectAll: selectModelsAll,
- selectById: selectModelsById,
- selectEntities: selectModelsEntities,
- selectIds: selectModelsIds,
- selectTotal: selectModelsTotal,
-} = modelsAdapter.getSelectors((state) => state.models);
-
-export const { modelAdded } = modelsSlice.actions;
-
-export default modelsSlice.reducer;
diff --git a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts b/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts
deleted file mode 100644
index aa9fb057e1..0000000000
--- a/invokeai/frontend/web/src/features/system/store/modelsPersistDenylist.ts
+++ /dev/null
@@ -1,6 +0,0 @@
-import { ModelsState } from './modelSlice';
-
-/**
- * Models slice persist denylist
- */
-export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
index b17f497f6c..688f69c1f7 100644
--- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts
+++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts
@@ -1,20 +1,12 @@
import { UseToastOptions } from '@chakra-ui/react';
-import { PayloadAction } from '@reduxjs/toolkit';
-import { createSlice } from '@reduxjs/toolkit';
+import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
-import { ProgressImage } from 'services/events/types';
-import { makeToast } from '../../../app/components/Toaster';
-import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
-import { receivedModels } from 'services/thunks/model';
-import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
-import { LogLevelName } from 'roarr';
import { InvokeLogLevel } from 'app/logging/useLogger';
-import { TFuncKey } from 'i18next';
-import { t } from 'i18next';
import { userInvoked } from 'app/store/actions';
-import { LANGUAGES } from '../components/LanguagePicker';
-import { imageUploaded } from 'services/thunks/image';
+import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
+import { TFuncKey, t } from 'i18next';
+import { LogLevelName } from 'roarr';
import {
appSocketConnected,
appSocketDisconnected,
@@ -26,6 +18,11 @@ import {
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
+import { ProgressImage } from 'services/events/types';
+import { imageUploaded } from 'services/thunks/image';
+import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
+import { makeToast } from '../../../app/components/Toaster';
+import { LANGUAGES } from '../components/LanguagePicker';
export type CancelStrategy = 'immediate' | 'scheduled';
@@ -95,6 +92,7 @@ export interface SystemState {
shouldAntialiasProgressImage: boolean;
language: keyof typeof LANGUAGES;
isUploading: boolean;
+ boardIdToAddTo?: string;
}
export const initialSystemState: SystemState = {
@@ -225,6 +223,7 @@ export const systemSlice = createSlice({
*/
builder.addCase(appSocketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId;
+ state.boardIdToAddTo = action.payload.boardId;
state.canceledSession = '';
});
@@ -233,6 +232,7 @@ export const systemSlice = createSlice({
*/
builder.addCase(appSocketUnsubscribed, (state) => {
state.sessionId = null;
+ state.boardIdToAddTo = undefined;
});
/**
@@ -376,13 +376,6 @@ export const systemSlice = createSlice({
);
});
- /**
- * Received available models from the backend
- */
- builder.addCase(receivedModels.fulfilled, (state) => {
- state.wereModelsReceived = true;
- });
-
/**
* OpenAPI schema was parsed
*/
diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts
index 7481a5daad..8ce42494e5 100644
--- a/invokeai/frontend/web/src/services/api/index.ts
+++ b/invokeai/frontend/web/src/services/api/index.ts
@@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation';
export type { BaseModelType } from './models/BaseModelType';
+export type { BoardChanges } from './models/BoardChanges';
+export type { BoardDTO } from './models/BoardDTO';
+export type { Body_create_board_image } from './models/Body_create_board_image';
+export type { Body_remove_board_image } from './models/Body_remove_board_image';
export type { Body_upload_image } from './models/Body_upload_image';
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo';
@@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation';
+export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
+export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
@@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
-export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
-export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
-export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
-export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
-export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
-export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
-export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
-export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
@@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { LoraInfo } from './models/LoraInfo';
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
+export type { LoRAModelConfig } from './models/LoRAModelConfig';
+export type { LoRAModelFormat } from './models/LoRAModelFormat';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
@@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
+export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
+export type { PipelineModelField } from './models/PipelineModelField';
+export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
@@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
-export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
-export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
+export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
+export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
+export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
+export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
+export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
+export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubModelType } from './models/SubModelType';
export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
+export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
export type { UNetField } from './models/UNetField';
export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeField } from './models/VaeField';
+export type { VaeModelConfig } from './models/VaeModelConfig';
+export type { VaeModelFormat } from './models/VaeModelFormat';
export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
+export { BoardsService } from './services/BoardsService';
export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService';
export { SessionsService } from './services/SessionsService';
diff --git a/invokeai/frontend/web/src/services/api/models/BoardChanges.ts b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts
new file mode 100644
index 0000000000..fb2bfa0cd9
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/BoardChanges.ts
@@ -0,0 +1,15 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+export type BoardChanges = {
+ /**
+ * The board's new name.
+ */
+ board_name?: string;
+ /**
+ * The name of the board's new cover image.
+ */
+ cover_image_name?: string;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/BoardDTO.ts b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts
new file mode 100644
index 0000000000..bbcc6f1dd6
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/BoardDTO.ts
@@ -0,0 +1,38 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * Deserialized board record with cover image URL and image count.
+ */
+export type BoardDTO = {
+ /**
+ * The unique ID of the board.
+ */
+ board_id: string;
+ /**
+ * The name of the board.
+ */
+ board_name: string;
+ /**
+ * The created timestamp of the board.
+ */
+ created_at: string;
+ /**
+ * The updated timestamp of the board.
+ */
+ updated_at: string;
+ /**
+ * The deleted timestamp of the board.
+ */
+ deleted_at?: string;
+ /**
+ * The name of the board's cover image.
+ */
+ cover_image_name?: string;
+ /**
+ * The number of images in the board.
+ */
+ image_count: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts
new file mode 100644
index 0000000000..47f8537eaa
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/Body_create_board_image.ts
@@ -0,0 +1,15 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+export type Body_create_board_image = {
+ /**
+ * The id of the board to add to
+ */
+ board_id: string;
+ /**
+ * The name of the image to add
+ */
+ image_name: string;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts
new file mode 100644
index 0000000000..6f5a3652d0
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/Body_remove_board_image.ts
@@ -0,0 +1,15 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+export type Body_remove_board_image = {
+ /**
+ * The id of the board
+ */
+ board_id: string;
+ /**
+ * The name of the image to remove
+ */
+ image_name: string;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts
new file mode 100644
index 0000000000..60e2958f5c
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelConfig.ts
@@ -0,0 +1,18 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BaseModelType } from './BaseModelType';
+import type { ControlNetModelFormat } from './ControlNetModelFormat';
+import type { ModelError } from './ModelError';
+
+export type ControlNetModelConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'controlnet';
+ path: string;
+ description?: string;
+ model_format: ControlNetModelFormat;
+ error?: ModelError;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts
new file mode 100644
index 0000000000..500b3e8f8c
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/ControlNetModelFormat.ts
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * An enumeration.
+ */
+export type ControlNetModelFormat = 'checkpoint' | 'diffusers';
diff --git a/invokeai/frontend/web/src/services/api/models/Graph.ts b/invokeai/frontend/web/src/services/api/models/Graph.ts
index e148954f16..5fba3d8311 100644
--- a/invokeai/frontend/web/src/services/api/models/Graph.ts
+++ b/invokeai/frontend/web/src/services/api/models/Graph.ts
@@ -49,6 +49,7 @@ import type { OpenposeImageProcessorInvocation } from './OpenposeImageProcessorI
import type { ParamFloatInvocation } from './ParamFloatInvocation';
import type { ParamIntInvocation } from './ParamIntInvocation';
import type { PidiImageProcessorInvocation } from './PidiImageProcessorInvocation';
+import type { PipelineModelLoaderInvocation } from './PipelineModelLoaderInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation';
@@ -56,8 +57,6 @@ import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
-import type { SD1ModelLoaderInvocation } from './SD1ModelLoaderInvocation';
-import type { SD2ModelLoaderInvocation } from './SD2ModelLoaderInvocation';
import type { ShowImageInvocation } from './ShowImageInvocation';
import type { StepParamEasingInvocation } from './StepParamEasingInvocation';
import type { SubtractInvocation } from './SubtractInvocation';
@@ -73,7 +72,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
- nodes?: Record;
+ nodes?: Record;
/**
* The connections between nodes and their fields in this graph
*/
diff --git a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
index 5399d16b8f..4e273e8854 100644
--- a/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
+++ b/invokeai/frontend/web/src/services/api/models/ImageDTO.ts
@@ -7,7 +7,7 @@ import type { ImageMetadata } from './ImageMetadata';
import type { ResourceOrigin } from './ResourceOrigin';
/**
- * Deserialized image record, enriched for the frontend with URLs.
+ * Deserialized image record, enriched for the frontend.
*/
export type ImageDTO = {
/**
@@ -66,5 +66,9 @@ export type ImageDTO = {
* A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.
*/
metadata?: ImageMetadata;
+ /**
+ * The id of the board the image belongs to, if one exists.
+ */
+ board_id?: string;
};
diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts
new file mode 100644
index 0000000000..184a266169
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/LoRAModelConfig.ts
@@ -0,0 +1,18 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BaseModelType } from './BaseModelType';
+import type { LoRAModelFormat } from './LoRAModelFormat';
+import type { ModelError } from './ModelError';
+
+export type LoRAModelConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'lora';
+ path: string;
+ description?: string;
+ model_format: LoRAModelFormat;
+ error?: ModelError;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts
new file mode 100644
index 0000000000..829f8a7c57
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/LoRAModelFormat.ts
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * An enumeration.
+ */
+export type LoRAModelFormat = 'lycoris' | 'diffusers';
diff --git a/invokeai/frontend/web/src/services/api/models/ModelsList.ts b/invokeai/frontend/web/src/services/api/models/ModelsList.ts
index a2d88d1967..9186db3e29 100644
--- a/invokeai/frontend/web/src/services/api/models/ModelsList.ts
+++ b/invokeai/frontend/web/src/services/api/models/ModelsList.ts
@@ -2,16 +2,16 @@
/* tslint:disable */
/* eslint-disable */
-import type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
-import type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './invokeai__backend__model_management__models__lora__LoRAModel__Config';
-import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
-import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
-import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
-import type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
-import type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
-import type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './invokeai__backend__model_management__models__vae__VaeModel__Config';
+import type { ControlNetModelConfig } from './ControlNetModelConfig';
+import type { LoRAModelConfig } from './LoRAModelConfig';
+import type { StableDiffusion1ModelCheckpointConfig } from './StableDiffusion1ModelCheckpointConfig';
+import type { StableDiffusion1ModelDiffusersConfig } from './StableDiffusion1ModelDiffusersConfig';
+import type { StableDiffusion2ModelCheckpointConfig } from './StableDiffusion2ModelCheckpointConfig';
+import type { StableDiffusion2ModelDiffusersConfig } from './StableDiffusion2ModelDiffusersConfig';
+import type { TextualInversionModelConfig } from './TextualInversionModelConfig';
+import type { VaeModelConfig } from './VaeModelConfig';
export type ModelsList = {
- models: Record>>;
+ models: Array<(StableDiffusion1ModelCheckpointConfig | StableDiffusion1ModelDiffusersConfig | VaeModelConfig | LoRAModelConfig | ControlNetModelConfig | TextualInversionModelConfig | StableDiffusion2ModelCheckpointConfig | StableDiffusion2ModelDiffusersConfig)>;
};
diff --git a/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts
new file mode 100644
index 0000000000..2e4734f469
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/OffsetPaginatedResults_BoardDTO_.ts
@@ -0,0 +1,28 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BoardDTO } from './BoardDTO';
+
+/**
+ * Offset-paginated results
+ */
+export type OffsetPaginatedResults_BoardDTO_ = {
+ /**
+ * Items
+ */
+ items: Array;
+ /**
+ * Offset from which to retrieve items
+ */
+ offset: number;
+ /**
+ * Limit of items to get
+ */
+ limit: number;
+ /**
+ * Total number of items in result
+ */
+ total: number;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts
new file mode 100644
index 0000000000..c2f1c07fbf
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/PipelineModelField.ts
@@ -0,0 +1,20 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BaseModelType } from './BaseModelType';
+
+/**
+ * Pipeline model field
+ */
+export type PipelineModelField = {
+ /**
+ * Name of the model
+ */
+ model_name: string;
+ /**
+ * Base model
+ */
+ base_model: BaseModelType;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts
similarity index 52%
rename from invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts
rename to invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts
index f477c11a8d..b8cdb27acf 100644
--- a/invokeai/frontend/web/src/services/api/models/SD2ModelLoaderInvocation.ts
+++ b/invokeai/frontend/web/src/services/api/models/PipelineModelLoaderInvocation.ts
@@ -2,10 +2,12 @@
/* tslint:disable */
/* eslint-disable */
+import type { PipelineModelField } from './PipelineModelField';
+
/**
- * Loading submodels of selected model.
+ * Loads a pipeline model, outputting its submodels.
*/
-export type SD2ModelLoaderInvocation = {
+export type PipelineModelLoaderInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
@@ -14,10 +16,10 @@ export type SD2ModelLoaderInvocation = {
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
- type?: 'sd2_model_loader';
+ type?: 'pipeline_model_loader';
/**
- * Model to load
+ * The model to load
*/
- model_name?: string;
+ model: PipelineModelField;
};
diff --git a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts b/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts
deleted file mode 100644
index 9a8a23077a..0000000000
--- a/invokeai/frontend/web/src/services/api/models/SD1ModelLoaderInvocation.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-/**
- * Loading submodels of selected model.
- */
-export type SD1ModelLoaderInvocation = {
- /**
- * The id of this node. Must be unique among all nodes.
- */
- id: string;
- /**
- * Whether or not this node is an intermediate node.
- */
- is_intermediate?: boolean;
- type?: 'sd1_model_loader';
- /**
- * Model to load
- */
- model_name?: string;
-};
-
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts
similarity index 60%
rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts
rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts
index 6bdcb87dd4..be7077cc53 100644
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig.ts
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelCheckpointConfig.ts
@@ -2,14 +2,17 @@
/* tslint:disable */
/* eslint-disable */
+import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType';
-export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig = {
+export type StableDiffusion1ModelCheckpointConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'pipeline';
path: string;
description?: string;
- format: 'checkpoint';
- default?: boolean;
+ model_format: 'checkpoint';
error?: ModelError;
vae?: string;
config?: string;
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts
similarity index 59%
rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts
rename to invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts
index c88e042178..befe014605 100644
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig.ts
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelDiffusersConfig.ts
@@ -2,14 +2,17 @@
/* tslint:disable */
/* eslint-disable */
+import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType';
-export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig = {
+export type StableDiffusion1ModelDiffusersConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'pipeline';
path: string;
description?: string;
- format: 'diffusers';
- default?: boolean;
+ model_format: 'diffusers';
error?: ModelError;
vae?: string;
variant: ModelVariantType;
diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts
new file mode 100644
index 0000000000..01b50c2fc0
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion1ModelFormat.ts
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * An enumeration.
+ */
+export type StableDiffusion1ModelFormat = 'checkpoint' | 'diffusers';
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts
similarity index 69%
rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts
rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts
index ec2ae4a845..dadd7cac9b 100644
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig.ts
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelCheckpointConfig.ts
@@ -2,15 +2,18 @@
/* tslint:disable */
/* eslint-disable */
+import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType';
import type { SchedulerPredictionType } from './SchedulerPredictionType';
-export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig = {
+export type StableDiffusion2ModelCheckpointConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'pipeline';
path: string;
description?: string;
- format: 'checkpoint';
- default?: boolean;
+ model_format: 'checkpoint';
error?: ModelError;
vae?: string;
config?: string;
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts
similarity index 68%
rename from invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts
rename to invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts
index 67b897d9d9..1e4a34c5dc 100644
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig.ts
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelDiffusersConfig.ts
@@ -2,15 +2,18 @@
/* tslint:disable */
/* eslint-disable */
+import type { BaseModelType } from './BaseModelType';
import type { ModelError } from './ModelError';
import type { ModelVariantType } from './ModelVariantType';
import type { SchedulerPredictionType } from './SchedulerPredictionType';
-export type invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig = {
+export type StableDiffusion2ModelDiffusersConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'pipeline';
path: string;
description?: string;
- format: 'diffusers';
- default?: boolean;
+ model_format: 'diffusers';
error?: ModelError;
vae?: string;
variant: ModelVariantType;
diff --git a/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts
new file mode 100644
index 0000000000..7e7b895231
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/StableDiffusion2ModelFormat.ts
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * An enumeration.
+ */
+export type StableDiffusion2ModelFormat = 'checkpoint' | 'diffusers';
diff --git a/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts
new file mode 100644
index 0000000000..97d6aa7ffa
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/TextualInversionModelConfig.ts
@@ -0,0 +1,17 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BaseModelType } from './BaseModelType';
+import type { ModelError } from './ModelError';
+
+export type TextualInversionModelConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'embedding';
+ path: string;
+ description?: string;
+ model_format: null;
+ error?: ModelError;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts
new file mode 100644
index 0000000000..a73ee0aa32
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/VaeModelConfig.ts
@@ -0,0 +1,18 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { BaseModelType } from './BaseModelType';
+import type { ModelError } from './ModelError';
+import type { VaeModelFormat } from './VaeModelFormat';
+
+export type VaeModelConfig = {
+ name: string;
+ base_model: BaseModelType;
+ type: 'vae';
+ path: string;
+ description?: string;
+ model_format: VaeModelFormat;
+ error?: ModelError;
+};
+
diff --git a/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts
new file mode 100644
index 0000000000..497f81d16f
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/models/VaeModelFormat.ts
@@ -0,0 +1,8 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+/**
+ * An enumeration.
+ */
+export type VaeModelFormat = 'checkpoint' | 'diffusers';
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts
deleted file mode 100644
index f8decdb341..0000000000
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-import type { ModelError } from './ModelError';
-
-export type invokeai__backend__model_management__models__controlnet__ControlNetModel__Config = {
- path: string;
- description?: string;
- format: ('checkpoint' | 'diffusers');
- default?: boolean;
- error?: ModelError;
-};
-
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts
deleted file mode 100644
index 614749a2c5..0000000000
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__lora__LoRAModel__Config.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-import type { ModelError } from './ModelError';
-
-export type invokeai__backend__model_management__models__lora__LoRAModel__Config = {
- path: string;
- description?: string;
- format: ('lycoris' | 'diffusers');
- default?: boolean;
- error?: ModelError;
-};
-
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts
deleted file mode 100644
index f23d5002e3..0000000000
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-import type { ModelError } from './ModelError';
-
-export type invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config = {
- path: string;
- description?: string;
- format: null;
- default?: boolean;
- error?: ModelError;
-};
-
diff --git a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts b/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts
deleted file mode 100644
index d9314a6063..0000000000
--- a/invokeai/frontend/web/src/services/api/models/invokeai__backend__model_management__models__vae__VaeModel__Config.ts
+++ /dev/null
@@ -1,14 +0,0 @@
-/* istanbul ignore file */
-/* tslint:disable */
-/* eslint-disable */
-
-import type { ModelError } from './ModelError';
-
-export type invokeai__backend__model_management__models__vae__VaeModel__Config = {
- path: string;
- description?: string;
- format: ('checkpoint' | 'diffusers');
- default?: boolean;
- error?: ModelError;
-};
-
diff --git a/invokeai/frontend/web/src/services/api/services/BoardsService.ts b/invokeai/frontend/web/src/services/api/services/BoardsService.ts
new file mode 100644
index 0000000000..236c765cb9
--- /dev/null
+++ b/invokeai/frontend/web/src/services/api/services/BoardsService.ts
@@ -0,0 +1,247 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+import type { BoardChanges } from '../models/BoardChanges';
+import type { BoardDTO } from '../models/BoardDTO';
+import type { Body_create_board_image } from '../models/Body_create_board_image';
+import type { Body_remove_board_image } from '../models/Body_remove_board_image';
+import type { OffsetPaginatedResults_BoardDTO_ } from '../models/OffsetPaginatedResults_BoardDTO_';
+import type { OffsetPaginatedResults_ImageDTO_ } from '../models/OffsetPaginatedResults_ImageDTO_';
+
+import type { CancelablePromise } from '../core/CancelablePromise';
+import { OpenAPI } from '../core/OpenAPI';
+import { request as __request } from '../core/request';
+
+export class BoardsService {
+
+ /**
+ * List Boards
+ * Gets a list of boards
+ * @returns any Successful Response
+ * @throws ApiError
+ */
+ public static listBoards({
+ all,
+ offset,
+ limit,
+ }: {
+ /**
+ * Whether to list all boards
+ */
+ all?: boolean,
+ /**
+ * The page offset
+ */
+ offset?: number,
+ /**
+ * The number of boards per page
+ */
+ limit?: number,
+ }): CancelablePromise<(OffsetPaginatedResults_BoardDTO_ | Array)> {
+ return __request(OpenAPI, {
+ method: 'GET',
+ url: '/api/v1/boards/',
+ query: {
+ 'all': all,
+ 'offset': offset,
+ 'limit': limit,
+ },
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Create Board
+ * Creates a board
+ * @returns BoardDTO The board was created successfully
+ * @throws ApiError
+ */
+ public static createBoard({
+ boardName,
+ }: {
+ /**
+ * The name of the board to create
+ */
+ boardName: string,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'POST',
+ url: '/api/v1/boards/',
+ query: {
+ 'board_name': boardName,
+ },
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Get Board
+ * Gets a board
+ * @returns BoardDTO Successful Response
+ * @throws ApiError
+ */
+ public static getBoard({
+ boardId,
+ }: {
+ /**
+ * The id of board to get
+ */
+ boardId: string,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'GET',
+ url: '/api/v1/boards/{board_id}',
+ path: {
+ 'board_id': boardId,
+ },
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Delete Board
+ * Deletes a board
+ * @returns any Successful Response
+ * @throws ApiError
+ */
+ public static deleteBoard({
+ boardId,
+ }: {
+ /**
+ * The id of board to delete
+ */
+ boardId: string,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'DELETE',
+ url: '/api/v1/boards/{board_id}',
+ path: {
+ 'board_id': boardId,
+ },
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Update Board
+ * Updates a board
+ * @returns BoardDTO The board was updated successfully
+ * @throws ApiError
+ */
+ public static updateBoard({
+ boardId,
+ requestBody,
+ }: {
+ /**
+ * The id of board to update
+ */
+ boardId: string,
+ requestBody: BoardChanges,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'PATCH',
+ url: '/api/v1/boards/{board_id}',
+ path: {
+ 'board_id': boardId,
+ },
+ body: requestBody,
+ mediaType: 'application/json',
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Create Board Image
+ * Creates a board_image
+ * @returns any The image was added to a board successfully
+ * @throws ApiError
+ */
+ public static createBoardImage({
+ requestBody,
+ }: {
+ requestBody: Body_create_board_image,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'POST',
+ url: '/api/v1/board_images/',
+ body: requestBody,
+ mediaType: 'application/json',
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * Remove Board Image
+ * Deletes a board_image
+ * @returns any The image was removed from the board successfully
+ * @throws ApiError
+ */
+ public static removeBoardImage({
+ requestBody,
+ }: {
+ requestBody: Body_remove_board_image,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'DELETE',
+ url: '/api/v1/board_images/',
+ body: requestBody,
+ mediaType: 'application/json',
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+ /**
+ * List Board Images
+ * Gets a list of images for a board
+ * @returns OffsetPaginatedResults_ImageDTO_ Successful Response
+ * @throws ApiError
+ */
+ public static listBoardImages({
+ boardId,
+ offset,
+ limit = 10,
+ }: {
+ /**
+ * The id of the board
+ */
+ boardId: string,
+ /**
+ * The page offset
+ */
+ offset?: number,
+ /**
+ * The number of boards per page
+ */
+ limit?: number,
+ }): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'GET',
+ url: '/api/v1/board_images/{board_id}',
+ path: {
+ 'board_id': boardId,
+ },
+ query: {
+ 'offset': offset,
+ 'limit': limit,
+ },
+ errors: {
+ 422: `Validation Error`,
+ },
+ });
+ }
+
+}
diff --git a/invokeai/frontend/web/src/services/api/services/ImagesService.ts b/invokeai/frontend/web/src/services/api/services/ImagesService.ts
index d0eae92d4b..bfdef887a0 100644
--- a/invokeai/frontend/web/src/services/api/services/ImagesService.ts
+++ b/invokeai/frontend/web/src/services/api/services/ImagesService.ts
@@ -25,6 +25,7 @@ export class ImagesService {
imageOrigin,
categories,
isIntermediate,
+ boardId,
offset,
limit = 10,
}: {
@@ -40,6 +41,10 @@ export class ImagesService {
* Whether to list intermediate images
*/
isIntermediate?: boolean,
+ /**
+ * The board id to filter by
+ */
+ boardId?: string,
/**
* The page offset
*/
@@ -56,6 +61,7 @@ export class ImagesService {
'image_origin': imageOrigin,
'categories': categories,
'is_intermediate': isIntermediate,
+ 'board_id': boardId,
'offset': offset,
'limit': limit,
},
diff --git a/invokeai/frontend/web/src/services/api/services/SessionsService.ts b/invokeai/frontend/web/src/services/api/services/SessionsService.ts
index 2e4a83b25f..51a36caad1 100644
--- a/invokeai/frontend/web/src/services/api/services/SessionsService.ts
+++ b/invokeai/frontend/web/src/services/api/services/SessionsService.ts
@@ -51,6 +51,7 @@ import type { PaginatedResults_GraphExecutionState_ } from '../models/PaginatedR
import type { ParamFloatInvocation } from '../models/ParamFloatInvocation';
import type { ParamIntInvocation } from '../models/ParamIntInvocation';
import type { PidiImageProcessorInvocation } from '../models/PidiImageProcessorInvocation';
+import type { PipelineModelLoaderInvocation } from '../models/PipelineModelLoaderInvocation';
import type { RandomIntInvocation } from '../models/RandomIntInvocation';
import type { RandomRangeInvocation } from '../models/RandomRangeInvocation';
import type { RangeInvocation } from '../models/RangeInvocation';
@@ -58,8 +59,6 @@ import type { RangeOfSizeInvocation } from '../models/RangeOfSizeInvocation';
import type { ResizeLatentsInvocation } from '../models/ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from '../models/RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from '../models/ScaleLatentsInvocation';
-import type { SD1ModelLoaderInvocation } from '../models/SD1ModelLoaderInvocation';
-import type { SD2ModelLoaderInvocation } from '../models/SD2ModelLoaderInvocation';
import type { ShowImageInvocation } from '../models/ShowImageInvocation';
import type { StepParamEasingInvocation } from '../models/StepParamEasingInvocation';
import type { SubtractInvocation } from '../models/SubtractInvocation';
@@ -175,7 +174,7 @@ export class SessionsService {
* The id of the session
*/
sessionId: string,
- requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
+ requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
}): CancelablePromise {
return __request(OpenAPI, {
method: 'POST',
@@ -212,7 +211,7 @@ export class SessionsService {
* The path to the node in the graph
*/
nodePath: string,
- requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | SD1ModelLoaderInvocation | SD2ModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
+ requestBody: (LoadImageInvocation | ShowImageInvocation | ImageCropInvocation | ImagePasteInvocation | MaskFromAlphaInvocation | ImageMultiplyInvocation | ImageChannelInvocation | ImageConvertInvocation | ImageBlurInvocation | ImageResizeInvocation | ImageScaleInvocation | ImageLerpInvocation | ImageInverseLerpInvocation | ControlNetInvocation | ImageProcessorInvocation | PipelineModelLoaderInvocation | LoraLoaderInvocation | DynamicPromptInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | ParamFloatInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | FloatLinearRangeInvocation | StepParamEasingInvocation | UpscaleInvocation | RestoreFaceInvocation | InpaintInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | CannyImageProcessorInvocation | HedImageProcessorInvocation | LineartImageProcessorInvocation | LineartAnimeImageProcessorInvocation | OpenposeImageProcessorInvocation | MidasDepthImageProcessorInvocation | NormalbaeImageProcessorInvocation | MlsdImageProcessorInvocation | PidiImageProcessorInvocation | ContentShuffleImageProcessorInvocation | ZoeDepthImageProcessorInvocation | MediapipeFaceProcessorInvocation | LatentsToLatentsInvocation),
}): CancelablePromise {
return __request(OpenAPI, {
method: 'PUT',
diff --git a/invokeai/frontend/web/src/services/apiSlice.ts b/invokeai/frontend/web/src/services/apiSlice.ts
new file mode 100644
index 0000000000..e2d765dd90
--- /dev/null
+++ b/invokeai/frontend/web/src/services/apiSlice.ts
@@ -0,0 +1,223 @@
+import {
+ TagDescription,
+ createApi,
+ fetchBaseQuery,
+} from '@reduxjs/toolkit/query/react';
+import { BoardDTO } from './api/models/BoardDTO';
+import { OffsetPaginatedResults_BoardDTO_ } from './api/models/OffsetPaginatedResults_BoardDTO_';
+import { BoardChanges } from './api/models/BoardChanges';
+import { OffsetPaginatedResults_ImageDTO_ } from './api/models/OffsetPaginatedResults_ImageDTO_';
+import { ImageDTO } from './api/models/ImageDTO';
+import {
+ FullTagDescription,
+ TagTypesFrom,
+ TagTypesFromApi,
+} from '@reduxjs/toolkit/dist/query/endpointDefinitions';
+import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
+import { BaseModelType } from './api/models/BaseModelType';
+import { ModelType } from './api/models/ModelType';
+import { ModelsList } from './api/models/ModelsList';
+import { keyBy } from 'lodash-es';
+
+type ListBoardsArg = { offset: number; limit: number };
+type UpdateBoardArg = { board_id: string; changes: BoardChanges };
+type AddImageToBoardArg = { board_id: string; image_name: string };
+type RemoveImageFromBoardArg = { board_id: string; image_name: string };
+type ListBoardImagesArg = { board_id: string; offset: number; limit: number };
+type ListModelsArg = { base_model?: BaseModelType; model_type?: ModelType };
+
+type ModelConfig = ModelsList['models'][number];
+
+const tagTypes = ['Board', 'Image', 'Model'];
+type ApiFullTagDescription = FullTagDescription<(typeof tagTypes)[number]>;
+
+const LIST = 'LIST';
+
+const modelsAdapter = createEntityAdapter({
+ selectId: (model) => getModelId(model),
+ sortComparer: (a, b) => a.name.localeCompare(b.name),
+});
+
+const getModelId = ({ base_model, type, name }: ModelConfig) =>
+ `${base_model}/${type}/${name}`;
+
+export const api = createApi({
+ baseQuery: fetchBaseQuery({ baseUrl: 'http://localhost:5173/api/v1/' }),
+ reducerPath: 'api',
+ tagTypes,
+ endpoints: (build) => ({
+ /**
+ * Models Queries
+ */
+
+ listModels: build.query, ListModelsArg>({
+ query: (arg) => ({ url: 'models/', params: arg }),
+ providesTags: (result, error, arg) => {
+ // any list of boards
+ const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST }];
+
+ if (result) {
+ // and individual tags for each board
+ tags.push(
+ ...result.ids.map((id) => ({
+ type: 'Model' as const,
+ id,
+ }))
+ );
+ }
+
+ return tags;
+ },
+ transformResponse: (response: ModelsList, meta, arg) => {
+ return modelsAdapter.addMany(
+ modelsAdapter.getInitialState(),
+ keyBy(response.models, getModelId)
+ );
+ },
+ }),
+ /**
+ * Boards Queries
+ */
+ listBoards: build.query({
+ query: (arg) => ({ url: 'boards/', params: arg }),
+ providesTags: (result, error, arg) => {
+ // any list of boards
+ const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
+
+ if (result) {
+ // and individual tags for each board
+ tags.push(
+ ...result.items.map(({ board_id }) => ({
+ type: 'Board' as const,
+ id: board_id,
+ }))
+ );
+ }
+
+ return tags;
+ },
+ }),
+
+ listAllBoards: build.query, void>({
+ query: () => ({
+ url: 'boards/',
+ params: { all: true },
+ }),
+ providesTags: (result, error, arg) => {
+ // any list of boards
+ const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST }];
+
+ if (result) {
+ // and individual tags for each board
+ tags.push(
+ ...result.map(({ board_id }) => ({
+ type: 'Board' as const,
+ id: board_id,
+ }))
+ );
+ }
+
+ return tags;
+ },
+ }),
+
+ /**
+ * Boards Mutations
+ */
+
+ createBoard: build.mutation({
+ query: (board_name) => ({
+ url: `boards/`,
+ method: 'POST',
+ params: { board_name },
+ }),
+ invalidatesTags: [{ id: 'Board', type: LIST }],
+ }),
+
+ updateBoard: build.mutation({
+ query: ({ board_id, changes }) => ({
+ url: `boards/${board_id}`,
+ method: 'PATCH',
+ body: changes,
+ }),
+ invalidatesTags: (result, error, arg) => [
+ { type: 'Board', id: arg.board_id },
+ ],
+ }),
+
+ deleteBoard: build.mutation({
+ query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
+ invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
+ }),
+
+ /**
+ * Board Images Queries
+ */
+
+ listBoardImages: build.query<
+ OffsetPaginatedResults_ImageDTO_,
+ ListBoardImagesArg
+ >({
+ query: ({ board_id, offset, limit }) => ({
+ url: `board_images/${board_id}`,
+ method: 'DELETE',
+ body: { offset, limit },
+ }),
+ }),
+
+ /**
+ * Board Images Mutations
+ */
+
+ addImageToBoard: build.mutation({
+ query: ({ board_id, image_name }) => ({
+ url: `board_images/`,
+ method: 'POST',
+ body: { board_id, image_name },
+ }),
+ invalidatesTags: (result, error, arg) => [
+ { type: 'Board', id: arg.board_id },
+ { type: 'Image', id: arg.image_name },
+ ],
+ }),
+
+ removeImageFromBoard: build.mutation({
+ query: ({ board_id, image_name }) => ({
+ url: `board_images/`,
+ method: 'DELETE',
+ body: { board_id, image_name },
+ }),
+ invalidatesTags: (result, error, arg) => [
+ { type: 'Board', id: arg.board_id },
+ { type: 'Image', id: arg.image_name },
+ ],
+ }),
+
+ /**
+ * Image Queries
+ */
+ getImageDTO: build.query({
+ query: (image_name) => ({ url: `images/${image_name}/metadata` }),
+ providesTags: (result, error, arg) => {
+ const tags: ApiFullTagDescription[] = [{ type: 'Image', id: arg }];
+ if (result?.board_id) {
+ tags.push({ type: 'Board', id: result.board_id });
+ }
+ return tags;
+ },
+ }),
+ }),
+});
+
+export const {
+ useListBoardsQuery,
+ useListAllBoardsQuery,
+ useCreateBoardMutation,
+ useUpdateBoardMutation,
+ useDeleteBoardMutation,
+ useAddImageToBoardMutation,
+ useRemoveImageFromBoardMutation,
+ useListBoardImagesQuery,
+ useGetImageDTOQuery,
+ useListModelsQuery,
+} = api;
diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts
index 5832cb24b1..ed154b9cd8 100644
--- a/invokeai/frontend/web/src/services/events/actions.ts
+++ b/invokeai/frontend/web/src/services/events/actions.ts
@@ -53,14 +53,14 @@ export const appSocketDisconnected = createAction(
* Do not use. Only for use in middleware.
*/
export const socketSubscribed = createAction<
- BaseSocketPayload & { sessionId: string }
+ BaseSocketPayload & { sessionId: string; boardId: string | undefined }
>('socket/socketSubscribed');
/**
* App-level Socket.IO Subscribed
*/
export const appSocketSubscribed = createAction<
- BaseSocketPayload & { sessionId: string }
+ BaseSocketPayload & { sessionId: string; boardId: string | undefined }
>('socket/appSocketSubscribed');
/**
diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts
index f1eb844f2c..5b427b1690 100644
--- a/invokeai/frontend/web/src/services/events/middleware.ts
+++ b/invokeai/frontend/web/src/services/events/middleware.ts
@@ -85,6 +85,7 @@ export const socketMiddleware = () => {
socketSubscribed({
sessionId: sessionId,
timestamp: getTimestamp(),
+ boardId: getState().boards.selectedBoardId,
})
);
}
diff --git a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
index 2c4cba510a..62b5864185 100644
--- a/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
+++ b/invokeai/frontend/web/src/services/events/util/setEventListeners.ts
@@ -44,6 +44,7 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
socketSubscribed({
sessionId,
timestamp: getTimestamp(),
+ boardId: getState().boards.selectedBoardId,
})
);
}
diff --git a/invokeai/frontend/web/src/services/thunks/image.ts b/invokeai/frontend/web/src/services/thunks/image.ts
index a0725bf235..fe198cf6f9 100644
--- a/invokeai/frontend/web/src/services/thunks/image.ts
+++ b/invokeai/frontend/web/src/services/thunks/image.ts
@@ -1,5 +1,6 @@
import { createAppAsyncThunk } from 'app/store/storeUtils';
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
+import { size } from 'lodash-es';
import { ImagesService } from 'services/api';
type imageUrlsReceivedArg = Parameters<
@@ -121,25 +122,61 @@ type ImagesListedArg = Parameters<
export const IMAGES_PER_PAGE = 20;
+const DEFAULT_IMAGES_LISTED_ARG = {
+ isIntermediate: false,
+ limit: IMAGES_PER_PAGE,
+};
+
/**
* `ImagesService.listImagesWithMetadata()` thunk
*/
export const receivedPageOfImages = createAppAsyncThunk(
'api/receivedPageOfImages',
- async (_, { getState }) => {
+ async (arg: ImagesListedArg, { getState }) => {
const state = getState();
const { categories } = state.images;
+ const { selectedBoardId } = state.boards;
- const totalImagesInFilter = selectImagesAll(state).filter((i) =>
- categories.includes(i.image_category)
- ).length;
-
- const response = await ImagesService.listImagesWithMetadata({
- categories,
- isIntermediate: false,
- offset: totalImagesInFilter,
- limit: IMAGES_PER_PAGE,
+ const images = selectImagesAll(state).filter((i) => {
+ const isInCategory = categories.includes(i.image_category);
+ const isInSelectedBoard = selectedBoardId
+ ? i.board_id === selectedBoardId
+ : true;
+ return isInCategory && isInSelectedBoard;
});
+
+ let queryArg: ReceivedImagesArg = {};
+
+ if (size(arg)) {
+ queryArg = {
+ ...DEFAULT_IMAGES_LISTED_ARG,
+ offset: images.length,
+ ...arg,
+ };
+ } else {
+ queryArg = {
+ ...DEFAULT_IMAGES_LISTED_ARG,
+ categories,
+ offset: images.length,
+ };
+ }
+
+ const response = await ImagesService.listImagesWithMetadata(queryArg);
+ return response;
+ }
+);
+
+type ReceivedImagesArg = Parameters<
+ (typeof ImagesService)['listImagesWithMetadata']
+>[0];
+
+/**
+ * `ImagesService.listImagesWithMetadata()` thunk
+ */
+export const receivedImages = createAppAsyncThunk(
+ 'api/receivedImages',
+ async (arg: ReceivedImagesArg, { getState }) => {
+ const response = await ImagesService.listImagesWithMetadata(arg);
return response;
}
);
diff --git a/invokeai/frontend/web/src/services/thunks/model.ts b/invokeai/frontend/web/src/services/thunks/model.ts
deleted file mode 100644
index 97f2bd8016..0000000000
--- a/invokeai/frontend/web/src/services/thunks/model.ts
+++ /dev/null
@@ -1,33 +0,0 @@
-import { log } from 'app/logging/useLogger';
-import { createAppAsyncThunk } from 'app/store/storeUtils';
-import { Model } from 'features/system/store/modelSlice';
-import { reduce, size } from 'lodash-es';
-import { ModelsService } from 'services/api';
-
-const models = log.child({ namespace: 'model' });
-
-export const IMAGES_PER_PAGE = 20;
-
-export const receivedModels = createAppAsyncThunk(
- 'models/receivedModels',
- async (_) => {
- const response = await ModelsService.listModels();
-
- const deserializedModels = reduce(
- response.models['sd-1']['pipeline'],
- (modelsAccumulator, model, modelName) => {
- modelsAccumulator[modelName] = { ...model, name: modelName };
-
- return modelsAccumulator;
- },
- {} as Record
- );
-
- models.info(
- { response },
- `Received ${size(response.models['sd-1']['pipeline'])} models`
- );
-
- return deserializedModels;
- }
-);
diff --git a/invokeai/frontend/web/src/services/types/guards.ts b/invokeai/frontend/web/src/services/types/guards.ts
index 334c04e6ed..7ac0d95e6a 100644
--- a/invokeai/frontend/web/src/services/types/guards.ts
+++ b/invokeai/frontend/web/src/services/types/guards.ts
@@ -11,6 +11,7 @@ import {
LatentsOutput,
ResourceOrigin,
ImageDTO,
+ BoardDTO,
} from 'services/api';
export const isImageDTO = (obj: unknown): obj is ImageDTO => {
@@ -29,6 +30,16 @@ export const isImageDTO = (obj: unknown): obj is ImageDTO => {
);
};
+export const isBoardDTO = (obj: unknown): obj is BoardDTO => {
+ return (
+ isObject(obj) &&
+ 'board_id' in obj &&
+ isString(obj?.board_id) &&
+ 'board_name' in obj &&
+ isString(obj?.board_name)
+ );
+};
+
export const isImageOutput = (
output: GraphExecutionState['results'][string]
): output is ImageOutput => output.type === 'image_output';
diff --git a/invokeai/frontend/web/src/theme/colors/greenTea.ts b/invokeai/frontend/web/src/theme/colors/greenTea.ts
index ffecbf2ffa..318aecbc61 100644
--- a/invokeai/frontend/web/src/theme/colors/greenTea.ts
+++ b/invokeai/frontend/web/src/theme/colors/greenTea.ts
@@ -4,8 +4,8 @@ import { generateColorPalette } from '../util/generateColorPalette';
export const greenTeaThemeColors: InvokeAIThemeColors = {
base: generateColorPalette(223, 10),
baseAlpha: generateColorPalette(223, 10, false, true),
- accent: generateColorPalette(155, 80),
- accentAlpha: generateColorPalette(155, 80, false, true),
+ accent: generateColorPalette(160, 60),
+ accentAlpha: generateColorPalette(160, 60, false, true),
working: generateColorPalette(47, 68),
workingAlpha: generateColorPalette(47, 68, false, true),
warning: generateColorPalette(28, 75),
@@ -14,5 +14,5 @@ export const greenTeaThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, false, true),
error: generateColorPalette(0, 50),
errorAlpha: generateColorPalette(0, 50, false, true),
- gridLineColor: 'rgba(255, 255, 255, 0.2)',
+ gridLineColor: 'rgba(255, 255, 255, 0.15)',
};
diff --git a/invokeai/frontend/web/src/theme/colors/invokeAI.ts b/invokeai/frontend/web/src/theme/colors/invokeAI.ts
index c39b3bed81..82db58bd35 100644
--- a/invokeai/frontend/web/src/theme/colors/invokeAI.ts
+++ b/invokeai/frontend/web/src/theme/colors/invokeAI.ts
@@ -2,8 +2,8 @@ import { InvokeAIThemeColors } from 'theme/themeTypes';
import { generateColorPalette } from 'theme/util/generateColorPalette';
export const invokeAIThemeColors: InvokeAIThemeColors = {
- base: generateColorPalette(225, 15),
- baseAlpha: generateColorPalette(225, 15, false, true),
+ base: generateColorPalette(220, 15),
+ baseAlpha: generateColorPalette(220, 15, false, true),
accent: generateColorPalette(250, 50),
accentAlpha: generateColorPalette(250, 50, false, true),
working: generateColorPalette(47, 67),
@@ -14,5 +14,5 @@ export const invokeAIThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(113, 70, false, true),
error: generateColorPalette(0, 76),
errorAlpha: generateColorPalette(0, 76, false, true),
- gridLineColor: 'rgba(255, 255, 255, 0.2)',
+ gridLineColor: 'rgba(150, 150, 180, 0.15)',
};
diff --git a/invokeai/frontend/web/src/theme/colors/lightTheme.ts b/invokeai/frontend/web/src/theme/colors/lightTheme.ts
index 2a7a05bbd2..2fdbd1a769 100644
--- a/invokeai/frontend/web/src/theme/colors/lightTheme.ts
+++ b/invokeai/frontend/web/src/theme/colors/lightTheme.ts
@@ -14,5 +14,5 @@ export const lightThemeColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, true, true),
error: generateColorPalette(0, 50, true),
errorAlpha: generateColorPalette(0, 50, true, true),
- gridLineColor: 'rgba(0, 0, 0, 0.2)',
+ gridLineColor: 'rgba(0, 0, 0, 0.15)',
};
diff --git a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts
index adfb8ab288..952e0a5066 100644
--- a/invokeai/frontend/web/src/theme/colors/oceanBlue.ts
+++ b/invokeai/frontend/web/src/theme/colors/oceanBlue.ts
@@ -14,5 +14,5 @@ export const oceanBlueColors: InvokeAIThemeColors = {
okAlpha: generateColorPalette(122, 49, false, true),
error: generateColorPalette(0, 100),
errorAlpha: generateColorPalette(0, 100, false, true),
- gridLineColor: 'rgba(136, 148, 184, 0.2)',
+ gridLineColor: 'rgba(136, 148, 184, 0.15)',
};