Merge branch 'main' into install-script-python-version-error-prompt-fix

This commit is contained in:
Lincoln Stein 2023-06-23 02:15:36 +01:00 committed by GitHub
commit df1907e849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
137 changed files with 4087 additions and 1018 deletions

View File

@ -2,8 +2,17 @@
from logging import Logger from logging import Logger
import os import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
@ -57,7 +66,7 @@ class ApiDependencies:
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = config.db_path db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True) db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState]( graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
@ -72,14 +81,40 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, services=ImageServiceDependencies(
image_file_storage=image_file_storage, board_image_record_storage=board_image_record_storage,
metadata=metadata, image_record_storage=image_record_storage,
url=urls, image_file_storage=image_file_storage,
logger=logger, metadata=metadata,
names=names, url=urls,
graph_execution_manager=graph_execution_manager, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
) )
services = InvocationServices( services = InvocationServices(
@ -87,6 +122,8 @@ class ApiDependencies:
events=events, events=events,
latents=latents, latents=latents,
images=images, images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs" filename=db_location, table_name="graphs"

View File

@ -0,0 +1,69 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
@board_images_router.post(
"/",
operation_id="create_board_image",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def create_board_image(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add to board")
@board_images_router.delete(
"/",
operation_id="remove_board_image",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
):
"""Deletes a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@board_images_router.get(
"/{board_id}",
operation_id="list_board_images",
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_board_images(
board_id: str = Path(description="The id of the board"),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images for a board"""
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
board_id,
)
return results

View File

@ -0,0 +1,108 @@
from typing import Optional, Union
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post(
"/",
operation_id="create_board",
responses={
201: {"description": "The board was created successfully"},
},
status_code=201,
response_model=BoardDTO,
)
async def create_board(
board_name: str = Query(description="The name of the board to create"),
) -> BoardDTO:
"""Creates a board"""
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
async def get_board(
board_id: str = Path(description="The id of board to get"),
) -> BoardDTO:
"""Gets a board"""
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception as e:
raise HTTPException(status_code=404, detail="Board not found")
@boards_router.patch(
"/{board_id}",
operation_id="update_board",
responses={
201: {
"description": "The board was updated successfully",
},
},
status_code=201,
response_model=BoardDTO,
)
async def update_board(
board_id: str = Path(description="The id of board to update"),
changes: BoardChanges = Body(description="The changes to apply to the board"),
) -> BoardDTO:
"""Updates a board"""
try:
result = ApiDependencies.invoker.services.boards.update(
board_id=board_id, changes=changes
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
) -> None:
"""Deletes a board"""
try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
)

View File

@ -221,6 +221,9 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query( is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images" default=None, description="Whether to list intermediate images"
), ),
board_id: Optional[str] = Query(
default=None, description="The board id to filter by"
),
offset: int = Query(default=0, description="The page offset"), offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"), limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
@ -232,6 +235,7 @@ async def list_images_with_metadata(
image_origin, image_origin,
categories, categories,
is_intermediate, is_intermediate,
board_id,
) )
return image_dtos return image_dtos

View File

@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(get_all_model_configs())] MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend models: list[MODEL_CONFIGS]
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: BaseModelType = Query( base_model: Optional[BaseModelType] = Query(
default=None, description="Base model" default=None, description="Base model"
), ),
model_type: ModelType = Query( model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get" default=None, description="The type of model to get"
), ),
) -> ModelsList: ) -> ModelsList:

View File

@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images from .api.routers import sessions, models, images, boards, board_images
from .api.sockets import SocketIO from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api") app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs # Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow? # TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi(): def custom_openapi():
@ -116,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema app.openapi_schema = openapi_schema
return app.openapi_schema return app.openapi_schema

View File

@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class SD1ModelLoaderInvocation(BaseInvocation): class PipelineModelField(BaseModel):
"""Loading submodels of selected model.""" """Pipeline model field"""
type: Literal["sd1_model_loader"] = "sd1_model_loader" model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_name: str = Field(default="", description="Model to load")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision? # TODO: precision?
# Schema customisation # Schema customisation
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"tags": ["model", "loader"], "tags": ["model", "loader"],
"type_hints": { "type_hints": {
"model_name": "model" # TODO: rename to model_name? "model": "model"
} }
}, },
} }
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion1 # TODO: base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
): ):
raise Exception(f"Unkown model name: {self.model_name}!") raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
""" """
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput( return ModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.UNet, submodel=SubModelType.UNet,
), ),
scheduler=ModelInfo( scheduler=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Scheduler, submodel=SubModelType.Scheduler,
), ),
loras=[], loras=[],
), ),
clip=ClipField( clip=ClipField(
tokenizer=ModelInfo( tokenizer=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Tokenizer, submodel=SubModelType.Tokenizer,
), ),
text_encoder=ModelInfo( text_encoder=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.TextEncoder, submodel=SubModelType.TextEncoder,
), ),
loras=[], loras=[],
), ),
vae=VaeField( vae=VaeField(
vae=ModelInfo( vae=ModelInfo(
model_name=self.model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=ModelType.Pipeline, model_type=model_type,
submodel=SubModelType.Vae,
),
)
)
# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd2_model_loader"] = "sd2_model_loader"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae, submodel=SubModelType.Vae,
), ),
) )

View File

@ -0,0 +1,254 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
board_id,
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(exclude={'cover_image_name'}),
cover_image_name=cover_image_name,
image_count=image_count,
)

View File

@ -0,0 +1,329 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
BoardRecord,
deserialize_board_record,
)
from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import (
BoardChanges,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC):
"""High-level service for board management."""
@abstractmethod
def create(
self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass
@abstractmethod
def get_dto(
self,
board_id: str,
) -> BoardDTO:
"""Gets a board."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
def __init__(self, services: BoardServiceDependencies):
self._services = services
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]: ) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
"""Saves an image record.""" """Saves an image record."""
pass pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
"""Gets the most recent image for a board."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase): class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str _filename: str
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.release() self._lock.release()
def _create_tables(self) -> None: def _create_tables(self) -> None:
"""Creates the tables for the `images` database.""" """Creates the `images` table."""
# Create the `images` table. # Create the `images` table.
self._cursor.execute( self._cursor.execute(
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE, is_intermediate BOOLEAN DEFAULT FALSE,
board_id TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE AFTER UPDATE
ON images FOR EACH ROW ON images FOR EACH ROW
BEGIN BEGIN
UPDATE images SET updated_at = current_timestamp UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name; WHERE image_name = old.image_name;
END; END;
""" """
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""", """,
(changes.is_intermediate, image_name), (changes.is_intermediate, image_name),
) )
self._conn.commit() self._conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]: ) -> OffsetPaginatedResults[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
# Manually build two queries - one for the count, one for the records # Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" images_query = """--sql
images_query = f"""SELECT * FROM images WHERE 1=1\n""" SELECT images.*
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = "" query_conditions = ""
query_params = [] query_params = []
if image_origin is not None: if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n""" query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value) query_params.append(image_origin.value)
if categories is not None: if categories is not None:
## Convert the enum values to unique list of strings # Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories))) category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders # Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings)) placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params # Unpack the included categories into the query params
for c in category_strings: for c in category_strings:
query_params.append(c) query_params.append(c)
if is_intermediate is not None: if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n""" query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate) query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" if board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination # Final images query with pagination
images_query += query_conditions + query_pagination + ";" images_query += query_conditions + query_pagination + ";"
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";" count_query += query_conditions + ";"
count_params = query_params.copy() count_params = query_params.copy()
self._cursor.execute(count_query, count_params) self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0] count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
raise e raise e
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordSaveException from e raise ImageRecordSaveException from e
finally: finally:
self._lock.release() self._lock.release()
def get_most_recent_image_for_board(
self, board_id: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None
return deserialize_image_record(dict(result))

View File

@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException, InvalidOriginException,
) )
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException, ImageRecordDeleteException,
ImageRecordNotFoundException, ImageRecordNotFoundException,
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
intermediate: bool = False, is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs.""" """Gets a paginated list of image DTOs."""
pass pass
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies: class ImageServiceDependencies:
"""Service dependencies for the ImageService.""" """Service dependencies for the ImageService."""
records: ImageRecordStorageBase image_records: ImageRecordStorageBase
files: ImageFileStorageBase image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
@ -126,14 +129,16 @@ class ImageServiceDependencies:
self, self,
image_record_storage: ImageRecordStorageBase, image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase, image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase, names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.image_records = image_record_storage
self.files = image_file_storage self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
@ -144,25 +149,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC): class ImageService(ImageServiceABC):
_services: ImageServiceDependencies _services: ImageServiceDependencies
def __init__( def __init__(self, services: ImageServiceDependencies):
self, self._services = services
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def create( def create(
self, self,
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
try: try:
# TODO: Consider using a transaction here to ensure consistency between storage and database # TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save( self._services.image_records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_origin=image_origin, image_origin=image_origin,
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
) )
self._services.files.save( self._services.image_files.save(
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_name) image_dto = self.get_dto(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
return ImageDTO( return image_dto
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to save image record") self._services.logger.error("Failed to save image record")
raise raise
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, changes) self._services.image_records.update(image_name, changes)
return self.get_dto(image_name) return self.get_dto(image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_name: str) -> PILImageType: def get_pil_image(self, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_name) return self._services.image_files.get(image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_name: str) -> ImageRecord: def get_record(self, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_name) return self._services.image_records.get(image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
def get_dto(self, image_name: str) -> ImageDTO: def get_dto(self, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_name) image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_name), self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True), self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
) )
return image_dto return image_dto
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
def get_path(self, image_name: str, thumbnail: bool = False) -> str: def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try: try:
return self._services.files.get_path(image_name, thumbnail) return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
def validate_path(self, path: str) -> bool: def validate_path(self, path: str) -> bool:
try: try:
return self._services.files.validate_path(path) return self._services.image_files.validate_path(path)
except Exception as e: except Exception as e:
self._services.logger.error("Problem validating image path") self._services.logger.error("Problem validating image path")
raise e raise e
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None, image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None, categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]: ) -> OffsetPaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.image_records.get_many(
offset, offset,
limit, limit,
image_origin, image_origin,
categories, categories,
is_intermediate, is_intermediate,
board_id,
) )
image_dtos = list( image_dtos = list(
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
r, r,
self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True), self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
), ),
results.items, results.items,
) )
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
def delete(self, image_name: str): def delete(self, image_name: str):
try: try:
self._services.files.delete(image_name) self._services.image_files.delete(image_name)
self._services.records.delete(image_name) self._services.image_records.delete(image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise

View File

@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from logging import Logger from logging import Logger
from invokeai.app.services.images import ImageService from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase from invokeai.app.services.latent_storage import LatentsStorageBase
@ -26,9 +28,9 @@ class InvocationServices:
model_manager: "ModelManager" model_manager: "ModelManager"
restoration: "RestorationServices" restoration: "RestorationServices"
configuration: "InvokeAISettings" configuration: "InvokeAISettings"
images: "ImageService" images: "ImageServiceABC"
boards: "BoardServiceABC"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"] graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"] graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC" processor: "InvocationProcessorABC"
@ -39,7 +41,9 @@ class InvocationServices:
events: "EventServiceBase", events: "EventServiceBase",
logger: "Logger", logger: "Logger",
latents: "LatentsStorageBase", latents: "LatentsStorageBase",
images: "ImageService", images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC", queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"], graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"], graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
@ -52,9 +56,12 @@ class InvocationServices:
self.logger = logger self.logger = logger
self.latents = latents self.latents = latents
self.images = images self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue self.queue = queue
self.graph_library = graph_library self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
self.processor = processor self.processor = processor
self.restoration = restoration self.restoration = restoration
self.configuration = configuration self.configuration = configuration
self.boards = boards

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
) -> bool: ) -> bool:
pass pass
@abstractmethod
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.
"""
pass
@abstractmethod
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass
@abstractmethod @abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_type, model_type,
) )
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
""" """
Given a model name returns a dict-like (OmegaConf) object describing it. Given a model name returns a dict-like (OmegaConf) object describing it.
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
) -> dict: ) -> list[dict]:
# ) -> dict:
""" """
Return a dict of models in the format: Return a list of models.
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)

View File

@ -0,0 +1,62 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field(
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field(
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
return BoardRecord(
board_id=board_id,
board_name=board_name,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)

View File

@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO): class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs.""" """Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass pass
def image_record_to_dto( def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO: ) -> ImageDTO:
"""Converts an image record to an image DTO.""" """Converts an image record to an image DTO."""
return ImageDTO( return ImageDTO(
**image_record.dict(), **image_record.dict(),
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
board_id=board_id,
) )

View File

@ -266,6 +266,8 @@ class ModelManager(object):
for model_key, model_config in config.items(): for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
@ -445,38 +447,6 @@ class ModelManager(object):
_cache = self.cache, _cache = self.cache,
) )
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_key, model_config in self.models.items():
if model_config.default:
return self.parse_key(model_key)
for model_key, _ in self.models.items():
return self.parse_key(model_key)
else:
return None # TODO: or redo as (None, None, None)
def set_default_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.models.items():
config.default = cur_model_key == model_key
def model_info( def model_info(
self, self,
model_name: str, model_name: str,
@ -503,9 +473,9 @@ class ModelManager(object):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]: ) -> list[dict]:
""" """
Return a dict of models, in format [base_model][model_type][model_name] Return a list of models.
Please use model_manager.models() to get all the model names, Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model model_manager.model_info('model-name') to get the stanza for the model
@ -513,7 +483,7 @@ class ModelManager(object):
object derived from models.yaml object derived from models.yaml
""" """
models = dict() models = []
for model_key in sorted(self.models, key=str.casefold): for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key] model_config = self.models[model_key]
@ -523,18 +493,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
if cur_base_model not in models: model_dict = dict(
models[cur_base_model] = dict()
if cur_model_type not in models[cur_base_model]:
models[cur_base_model][cur_model_type] = dict()
models[cur_base_model][cur_model_type][cur_model_name] = dict(
**model_config.dict(exclude_defaults=True), **model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name, name=cur_model_name,
base_model=cur_base_model, base_model=cur_base_model,
type=cur_model_type, type=cur_model_type,
) )
models.append(model_dict)
return models return models
def print_models(self) -> None: def print_models(self) -> None:
@ -646,7 +614,9 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config: if model_class.save_to_config:
# TODO: or exclude_unset better fits here? # TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True) data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save) yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path config_file_path = conf_file or self.config_path

View File

@ -1,3 +1,7 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
@ -29,10 +33,63 @@ MODEL_CLASSES = {
#}, #},
} }
def get_all_model_configs(): MODEL_CONFIGS = list()
configs = set() OPENAPI_MODEL_CONFIGS = list()
for models in MODEL_CLASSES.values():
for _, model in models.items(): class OpenAPIModelInfoBase(BaseModel):
configs.update(model._get_configs().values()) name: str
configs.discard(None) base_model: BaseModelType
return list(configs) # TODO: set, list or tuple type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -48,12 +48,10 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel): class ModelConfigBase(BaseModel):
path: str # or Path path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None) description: Optional[str] = Field(None)
format: Optional[str] = Field(None) model_format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
# do not save to config # do not save to config
error: Optional[ModelError] = Field(None, exclude=True) error: Optional[ModelError] = Field(None)
class Config: class Config:
use_enum_values = True use_enum_values = True
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
def _hf_definition_to_type(self, subtypes: List[str]) -> Type: def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2: if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!") raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]: if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]] res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:] subtypes = subtypes[1:]
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
continue continue
fields = inspect.get_annotations(value) fields = inspect.get_annotations(value)
if "format" not in fields: try:
raise Exception("Invalid config definition - format field not found") field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
format_type = typing.get_origin(fields["format"]) if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
if format_type not in {None, Literal, Union}: for model_format in field:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__): elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}") for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["format"].__args__
else: else:
f_fields = (fields["format"],) raise Exception(f"Unsupported format definition in {cls.__qualname__}")
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
cls.__configs = configs cls.__configs = configs
return cls.__configs return cls.__configs
@classmethod @classmethod
def create_config(cls, **kwargs) -> ModelConfigBase: def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs: if "model_format" not in kwargs:
raise Exception("Field 'format' not found in model config") raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs() configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs) return configs[kwargs["model_format"]](**kwargs)
@classmethod @classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase: def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config( return cls.create_config(
path=path, path=path,
format=cls.detect_format(path), model_format=cls.detect_format(path),
) )
@classmethod @classmethod

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -14,12 +15,16 @@ from .base import (
classproperty, classproperty,
) )
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase): class ControlNetModel(ModelBase):
#model_class: Type #model_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet assert model_type == ModelType.ControlNet
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return ControlNetModelFormat.Diffusers
else: else:
return "checkpoint" return ControlNetModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImlemetedError("Checkpoint controlnet models currently unsupported") raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else: else:
return model_path return model_path

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from enum import Enum
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
ModelBase, ModelBase,
@ -12,11 +13,15 @@ from .base import (
# TODO: naming # TODO: naming
from ..lora import LoRAModel as LoRAModelRaw from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase): class LoRAModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]] model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora assert model_type == ModelType.Lora
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return LoRAModelFormat.Diffusers
else: else:
return "lycoris" return LoRAModelFormat.LyCORIS
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) == "diffusers": if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit # TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported") raise NotImplementedError("Diffusers lora not supported")
else: else:

View File

@ -1,5 +1,6 @@
import os import os
import json import json
from enum import Enum
from pydantic import Field from pydantic import Field
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel): class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config( return cls.create_config(
path=path, path=path,
format=model_format, model_format=model_format,
config=ckpt_config_path, config=ckpt_config_path,
variant=variant, variant=variant,
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion1ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion1ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else: else:
return model_path return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel): class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly # TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
variant: ModelVariantType variant: ModelVariantType
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs): def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path) model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None) ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint": if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path: if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path) ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"] ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers": elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config( return cls.create_config(
path=path, path=path,
format=model_format, model_format=model_format,
config=ckpt_config_path, config=ckpt_config_path,
variant=variant, variant=variant,
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod @classmethod
def detect_format(cls, model_path: str): def detect_format(cls, model_path: str):
if os.path.isdir(model_path): if os.path.isdir(model_path):
return "diffusers" return StableDiffusion2ModelFormat.Diffusers
else: else:
return "checkpoint" return StableDiffusion2ModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2: elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention upcast_attention = model_config.upcast_attention
prediction_type = config.prediction_type prediction_type = model_config.prediction_type
else: else:
raise Exception(f"Unknown model provided: {version}") raise Exception(f"Unknown model provided: {version}")

View File

@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: None model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion assert model_type == ModelType.TextualInversion

View File

@ -1,5 +1,7 @@
import os import os
import torch import torch
import safetensors
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional, Union, Literal
from .base import ( from .base import (
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase): class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
#model_size: int #model_size: int
class Config(ModelConfigBase): class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]] model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae assert model_type == ModelType.Vae
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if os.path.isdir(path): if os.path.isdir(path):
return "diffusers" return VaeModelFormat.Diffusers
else: else:
return "checkpoint" return VaeModelFormat.Checkpoint
@classmethod @classmethod
def convert_if_required( def convert_if_required(
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache( return _convert_vae_ckpt_and_cache(
weights_path=model_path, weights_path=model_path,
output_path=output_path, output_path=output_path,

View File

@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -45,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady(); const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false); const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -143,6 +157,7 @@ const App = ({
</Portal> </Portal>
</Grid> </Grid>
<DeleteImageModal /> <DeleteImageModal />
<UpdateImageBoardModal />
<Toaster /> <Toaster />
<GlobalHotkeys /> <GlobalHotkeys />
</> </>

View File

@ -21,6 +21,8 @@ import {
DeleteImageContext, DeleteImageContext,
DeleteImageContextProvider, DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext'; } from 'app/contexts/DeleteImageContext';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
const App = lazy(() => import('./App')); const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -76,11 +78,13 @@ const InvokeAIUI = ({
<ThemeLocaleProvider> <ThemeLocaleProvider>
<ImageDndContext> <ImageDndContext>
<DeleteImageContextProvider> <DeleteImageContextProvider>
<App <AddImageToBoardContextProvider>
config={config} <App
headerComponent={headerComponent} config={config}
setIsReady={setIsReady} headerComponent={headerComponent}
/> setIsReady={setIsReady}
/>
</AddImageToBoardContextProvider>
</DeleteImageContextProvider> </DeleteImageContextProvider>
</ImageDndContext> </ImageDndContext>
</ThemeLocaleProvider> </ThemeLocaleProvider>

View File

@ -0,0 +1,89 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api';
import { useAddImageToBoardMutation } from 'services/apiSlice';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
type AddImageToBoardContextValue = {
/**
* Whether the move image dialog is open.
*/
isOpen: boolean;
/**
* Closes the move image dialog.
*/
onClose: () => void;
/**
* The image pending movement
*/
image?: ImageDTO;
onClickAddToBoard: (image: ImageDTO) => void;
handleAddToBoard: (boardId: string) => void;
};
export const AddImageToBoardContext =
createContext<AddImageToBoardContextValue>({
isOpen: false,
onClose: () => undefined,
onClickAddToBoard: () => undefined,
handleAddToBoard: () => undefined,
});
type Props = PropsWithChildren;
export const AddImageToBoardContextProvider = (props: Props) => {
const [imageToMove, setImageToMove] = useState<ImageDTO>();
const { isOpen, onOpen, onClose } = useDisclosure();
const [addImageToBoard, result] = useAddImageToBoardMutation();
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToMove(undefined);
onClose();
}, [onClose]);
const onClickAddToBoard = useCallback(
(image?: ImageDTO) => {
if (!image) {
return;
}
setImageToMove(image);
onOpen();
},
[setImageToMove, onOpen]
);
const handleAddToBoard = useCallback(
(boardId: string) => {
if (imageToMove) {
addImageToBoard({
board_id: boardId,
image_name: imageToMove.image_name,
});
closeAndClearImageToDelete();
}
},
[addImageToBoard, closeAndClearImageToDelete, imageToMove]
);
return (
<AddImageToBoardContext.Provider
value={{
isOpen,
image: imageToMove,
onClose: closeAndClearImageToDelete,
onClickAddToBoard,
handleAddToBoard,
}}
>
{props.children}
</AddImageToBoardContext.Provider>
);
};

View File

@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
(state: RootState, image_name?: string) => image_name, (state: RootState, image_name?: string) => image_name,
], ],
(generation, canvas, nodes, controlNet, image_name) => { (generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.image_name === image_name; const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some( const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.image.image_name === image_name (obj) => obj.kind === 'image' && obj.imageName === image_name
); );
const isNodesImage = nodes.nodes.some((node) => { const isNodesImage = nodes.nodes.some((node) => {
return some( return some(
node.data.inputs, node.data.inputs,
(input) => (input) => input.type === 'image' && input.value === image_name
input.type === 'image' && input.value?.image_name === image_name
); );
}); });
const isControlNetImage = some( const isControlNetImage = some(
controlNet.controlNets, controlNet.controlNets,
(c) => (c) =>
c.controlImage?.image_name === image_name || c.controlImage === image_name || c.processedControlImage === image_name
c.processedControlImage?.image_name === image_name
); );
const imageUsage: ImageUsage = { const imageUsage: ImageUsage = {

View File

@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist'; import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist'; import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist'; import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist'; import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist'; import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es'; import { omit } from 'lodash-es';
@ -18,7 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist, gallery: galleryPersistDenylist,
generation: generationPersistDenylist, generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist, lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist, nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist, postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist, system: systemPersistDenylist,

View File

@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice'; import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice'; import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice'; import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice'; import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice'; import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice'; import { initialUIState } from 'features/ui/store/uiSlice';
@ -21,7 +20,6 @@ const initialStates: {
gallery: initialGalleryState, gallery: initialGalleryState,
generation: initialGenerationState, generation: initialGenerationState,
lightbox: initialLightboxState, lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState, nodes: initialNodesState,
postprocessing: initialPostprocessingState, postprocessing: initialPostprocessingState,
system: initialSystemState, system: initialSystemState,

View File

@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed'; import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess'; import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect'; import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
import {
addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener,
} from './listeners/imageAddedToBoard';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import {
addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch AppDispatch
>; >;
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded // Image uploaded
addImageUploadedFulfilledListener(); addImageUploadedFulfilledListener();
addImageUploadedRejectedListener(); addImageUploadedRejectedListener();
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
// Update image URLs on connect // Update image URLs on connect
addUpdateImageUrlsOnConnectListener(); addUpdateImageUrlsOnConnectListener();
// Boards
addImageAddedToBoardFulfilledListener();
addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener();

View File

@ -0,0 +1,99 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { boardIdSelected } from 'features/gallery/store/boardSlice';
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
import { api } from 'services/apiSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addBoardIdSelectedListener = () => {
startAppListening({
actionCreator: boardIdSelected,
effect: (action, { getState, dispatch }) => {
const boardId = action.payload;
// we need to check if we need to fetch more images
const state = getState();
const allImages = selectImagesAll(state);
if (!boardId) {
// a board was unselected
dispatch(imageSelected(allImages[0]?.image_name));
return;
}
const { categories } = state.images;
const filteredImages = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
return isInCategory && isInSelectedBoard;
});
// get the board from the cache
const { data: boards } = api.endpoints.listAllBoards.select()(state);
const board = boards?.find((b) => b.board_id === boardId);
if (!board) {
// can't find the board in cache...
dispatch(imageSelected(allImages[0]?.image_name));
return;
}
dispatch(imageSelected(board.cover_image_name));
// if we haven't loaded one full page of images from this board, load more
if (
filteredImages.length < board.image_count &&
filteredImages.length < IMAGES_PER_PAGE
) {
dispatch(receivedPageOfImages({ categories, boardId }));
}
},
});
};
export const addBoardIdSelected_changeSelectedImage_listener = () => {
startAppListening({
actionCreator: boardIdSelected,
effect: (action, { getState, dispatch }) => {
const boardId = action.payload;
const state = getState();
// we need to check if we need to fetch more images
if (!boardId) {
// a board was unselected - we don't need to do anything
return;
}
const { categories } = state.images;
const filteredImages = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
return isInCategory && isInSelectedBoard;
});
// get the board from the cache
const { data: boards } = api.endpoints.listAllBoards.select()(state);
const board = boards?.find((b) => b.board_id === boardId);
if (!board) {
// can't find the board in cache...
return;
}
// if we haven't loaded one full page of images from this board, load more
if (
filteredImages.length < board.image_count &&
filteredImages.length < IMAGES_PER_PAGE
) {
dispatch(receivedPageOfImages({ categories, boardId }));
}
},
});
};

View File

@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: { [controlNet.processorNode.id]: {
...controlNet.processorNode, ...controlNet.processorNode,
is_intermediate: true, is_intermediate: true,
image: pick(controlNet.controlImage, ['image_name']), image: { image_name: controlNet.controlImage },
}, },
}, },
}; };
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
dispatch( dispatch(
controlNetProcessedImageChanged({ controlNetProcessedImageChanged({
controlNetId, controlNetId,
processedControlImage, processedControlImage: processedControlImage.image_name,
}) })
); );
} }

View File

@ -0,0 +1,40 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addImageAddedToBoardFulfilledListener = () => {
startAppListening({
matcher: api.endpoints.addImageToBoard.matchFulfilled,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Image added to board'
);
dispatch(
imageMetadataReceived({
imageName: image_name,
})
);
},
});
};
export const addImageAddedToBoardRejectedListener = () => {
startAppListening({
matcher: api.endpoints.addImageToBoard.matchRejected,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Problem adding image to board'
);
},
});
};

View File

@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageCategoriesChanged, actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const filteredImagesCount = selectFilteredImagesAsArray( const state = getState();
getState() const filteredImagesCount = selectFilteredImagesAsArray(state).length;
).length;
if (!filteredImagesCount) { if (!filteredImagesCount) {
dispatch(receivedPageOfImages()); dispatch(
receivedPageOfImages({
categories: action.payload,
boardId: state.boards.selectedBoardId,
})
);
} }
}, },
}); });

View File

@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { import {
imageRemoved, imageRemoved,
selectImagesEntities,
selectImagesIds, selectImagesIds,
} from 'features/gallery/store/imagesSlice'; } from 'features/gallery/store/imagesSlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'image' });
/** /**
* Called when the user requests an image deletion * Called when the user requests an image deletion
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => { export const addRequestedImageDeletionListener = () => {
startAppListening({ startAppListening({
actionCreator: requestedImageDeletion, actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => { effect: async (action, { dispatch, getState, condition }) => {
const { image, imageUsage } = action.payload; const { image, imageUsage } = action.payload;
const { image_name } = image; const { image_name } = image;
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
const state = getState(); const state = getState();
const selectedImage = state.gallery.selectedImage; const selectedImage = state.gallery.selectedImage;
if (selectedImage && selectedImage.image_name === image_name) { if (selectedImage === image_name) {
const ids = selectImagesIds(state); const ids = selectImagesIds(state);
const entities = selectImagesEntities(state);
const deletedImageIndex = ids.findIndex( const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex]; const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = entities[newSelectedImageId];
if (newSelectedImageId) { if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage)); dispatch(imageSelected(newSelectedImageId as string));
} else { } else {
dispatch(imageSelected()); dispatch(imageSelected());
} }
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name)); dispatch(imageRemoved(image_name));
// Delete from server // Delete from server
dispatch(imageDeleted({ imageName: image_name })); const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
// Wait for successful deletion, then trigger boards to re-fetch
const wasImageDeleted = await condition(
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
imageDeleted.fulfilled.match(action) &&
action.meta.requestId === requestId,
30000
);
if (wasImageDeleted) {
dispatch(
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
);
}
}, },
}); });
}; };

View File

@ -0,0 +1,40 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addImageRemovedFromBoardFulfilledListener = () => {
startAppListening({
matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Image added to board'
);
dispatch(
imageMetadataReceived({
imageName: image_name,
})
);
},
});
};
export const addImageRemovedFromBoardRejectedListener = () => {
startAppListening({
matcher: api.endpoints.removeImageFromBoard.matchRejected,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Problem adding image to board'
);
},
});
};

View File

@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') { if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction; const { controlNetId } = postUploadAction;
dispatch(controlNetImageChanged({ controlNetId, controlImage: image })); dispatch(
controlNetImageChanged({
controlNetId,
controlImage: image.image_name,
})
);
return; return;
} }

View File

@ -1,9 +1,8 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image'; import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected'); moduleLog.debug({ timestamp }, 'Connected');
const { models, nodes, config, images } = getState(); const { nodes, config, images } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
if (!images.ids.length) { if (!images.ids.length) {
dispatch(receivedPageOfImages()); dispatch(
} receivedPageOfImages({
categories: ['general'],
if (!models.ids.length) { isIntermediate: false,
dispatch(receivedModels()); })
);
} }
if (!nodes.schema && !disabledTabs.includes('nodes')) { if (!nodes.schema && !disabledTabs.includes('nodes')) {

View File

@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session'; import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards'; import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice'; import { progressImageSet } from 'features/system/store/systemSlice';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image']; const nodeDenylist = ['dataURL_image'];
@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
const sessionId = action.payload.data.graph_execution_state_id; const sessionId = action.payload.data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system; const { cancelType, isCancelScheduled, boardIdToAddTo } =
getState().system;
// Handle scheduled cancelation // Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) { if (cancelType === 'scheduled' && isCancelScheduled) {
@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
dispatch(addImageToStagingArea(imageDTO)); dispatch(addImageToStagingArea(imageDTO));
} }
if (boardIdToAddTo && !imageDTO.is_intermediate) {
dispatch(
api.endpoints.addImageToBoard.initiate({
board_id: boardIdToAddTo,
image_name,
})
);
}
dispatch(progressImageSet(null)); dispatch(progressImageSet(null));
} }
// pass along the socket event as an application action // pass along the socket event as an application action

View File

@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
selectImagesEntities, selectImagesEntities,
], ],
(generation, canvas, nodes, controlNet, imageEntities) => { (generation, canvas, nodes, controlNet, imageEntities) => {
const allUsedImages: ImageDTO[] = []; const allUsedImages: string[] = [];
if (generation.initialImage) { if (generation.initialImage) {
allUsedImages.push(generation.initialImage); allUsedImages.push(generation.initialImage.imageName);
} }
canvas.layerState.objects.forEach((obj) => { canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') { if (obj.kind === 'image') {
allUsedImages.push(obj.image); allUsedImages.push(obj.imageName);
} }
}); });
@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
forEach(imageEntities, (image) => { forEach(imageEntities, (image) => {
if (image) { if (image) {
allUsedImages.push(image); allUsedImages.push(image.image_name);
} }
}); });
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images` `Fetching new image URLs for ${allUsedImages.length} images`
); );
allUsedImages.forEach(({ image_name }) => { allUsedImages.forEach((image_name) => {
dispatch( dispatch(
imageUrlsReceived({ imageUrlsReceived({
imageName: image_name, imageName: image_name,

View File

@ -5,40 +5,39 @@ import {
configureStore, configureStore,
} from '@reduxjs/toolkit'; } from '@reduxjs/toolkit';
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares'; import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice'; import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice'; import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice'; import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice'; // import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize'; import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize'; import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants'; import { api } from 'services/apiSlice';
const allReducers = { const allReducers = {
canvas: canvasReducer, canvas: canvasReducer,
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
lightbox: lightboxReducer, lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer, nodes: nodesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
@ -47,7 +46,9 @@ const allReducers = {
hotkeys: hotkeysReducer, hotkeys: hotkeysReducer,
images: imagesReducer, images: imagesReducer,
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer,
// session: sessionReducer, // session: sessionReducer,
[api.reducerPath]: api.reducer,
}; };
const rootReducer = combineReducers(allReducers); const rootReducer = combineReducers(allReducers);
@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'gallery', 'gallery',
'generation', 'generation',
'lightbox', 'lightbox',
// 'models',
'nodes', 'nodes',
'postprocessing', 'postprocessing',
'system', 'system',
'ui', 'ui',
'controlNet', 'controlNet',
// 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',
]; ];
@ -84,6 +85,7 @@ export const store = configureStore({
immutableCheck: false, immutableCheck: false,
serializableCheck: false, serializableCheck: false,
}) })
.concat(api.middleware)
.concat(dynamicMiddlewares) .concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware), .prepend(listenerMiddleware.middleware),
devTools: { devTools: {

View File

@ -9,7 +9,7 @@ import {
import { useDraggable, useDroppable } from '@dnd-kit/core'; import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities'; import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { IAIImageFallback } from 'common/components/IAIImageFallback'; import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent, useCallback } from 'react'; import { ReactElement, SyntheticEvent, useCallback } from 'react';
@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
isDropDisabled = false, isDropDisabled = false,
isDragDisabled = false, isDragDisabled = false,
isUploadDisabled = false, isUploadDisabled = false,
fallback = <IAIImageFallback />, fallback = <IAIImageLoadingFallback />,
payloadImage, payloadImage,
minSize = 24, minSize = 24,
postUploadAction, postUploadAction,

View File

@ -1,10 +1,20 @@
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react'; import {
As,
Flex,
FlexProps,
Icon,
IconProps,
Spinner,
SpinnerProps,
} from '@chakra-ui/react';
import { ReactElement } from 'react';
import { FaImage } from 'react-icons/fa';
type Props = FlexProps & { type Props = FlexProps & {
spinnerProps?: SpinnerProps; spinnerProps?: SpinnerProps;
}; };
export const IAIImageFallback = (props: Props) => { export const IAIImageLoadingFallback = (props: Props) => {
const { spinnerProps, ...rest } = props; const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest; const { sx, ...restFlexProps } = rest;
return ( return (
@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
</Flex> </Flex>
); );
}; };
type IAINoImageFallbackProps = {
flexProps?: FlexProps;
iconProps?: IconProps;
as?: As;
};
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
return (
<Flex
sx={{
bg: 'base.900',
opacity: 0.7,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
...flexSx,
}}
{...restFlexProps}
>
<Icon
as={props.as ?? FaImage}
sx={{ color: 'base.700', ...iconSx }}
{...restIconProps}
/>
</Flex>
);
};

View File

@ -1,14 +1,21 @@
import { Image } from 'react-konva'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Image, Rect } from 'react-konva';
import { useGetImageDTOQuery } from 'services/apiSlice';
import useImage from 'use-image'; import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageProps = { type IAICanvasImageProps = {
url: string; canvasImage: CanvasImage;
x: number;
y: number;
}; };
const IAICanvasImage = (props: IAICanvasImageProps) => { const IAICanvasImage = (props: IAICanvasImageProps) => {
const { url, x, y } = props; const { width, height, x, y, imageName } = props.canvasImage;
const [image] = useImage(url, 'anonymous'); const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
}
return <Image x={x} y={y} image={image} listening={false} />; return <Image x={x} y={y} image={image} listening={false} />;
}; };

View File

@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
<Group name="outpainting-objects" listening={false}> <Group name="outpainting-objects" listening={false}>
{objects.map((obj, i) => { {objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) { if (isCanvasBaseImage(obj)) {
return ( return <IAICanvasImage key={i} canvasImage={obj} />;
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={obj.image.image_url}
/>
);
} else if (isCanvasBaseLine(obj)) { } else if (isCanvasBaseLine(obj)) {
const line = ( const line = (
<Line <Line

View File

@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
return ( return (
<Group {...rest}> <Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && ( {shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage <IAICanvasImage canvasImage={currentStagingAreaImage} />
url={currentStagingAreaImage.image.image_url}
x={x}
y={y}
/>
)} )}
{shouldShowStagingOutline && ( {shouldShowStagingOutline && (
<Group> <Group>

View File

@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
y: 0, y: 0,
width: width, width: width,
height: height, height: height,
image: image, imageName: image.image_name,
}, },
], ],
}; };
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
kind: 'image', kind: 'image',
layer: 'base', layer: 'base',
...state.layerState.stagingArea.boundingBox, ...state.layerState.stagingArea.boundingBox,
image, imageName: image.image_name,
}); });
state.layerState.stagingArea.selectedImageIndex = state.layerState.stagingArea.selectedImageIndex =
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
state.doesCanvasNeedScaling = true; state.doesCanvasNeedScaling = true;
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload; // const { image_name, image_url, thumbnail_url } = action.payload;
state.layerState.objects.forEach((object) => { // state.layerState.objects.forEach((object) => {
if (object.kind === 'image') { // if (object.kind === 'image') {
if (object.image.image_name === image_name) { // if (object.image.image_name === image_name) {
object.image.image_url = image_url; // object.image.image_url = image_url;
object.image.thumbnail_url = thumbnail_url; // object.image.thumbnail_url = thumbnail_url;
} // }
} // }
}); // });
state.layerState.stagingArea.images.forEach((stagedImage) => { // state.layerState.stagingArea.images.forEach((stagedImage) => {
if (stagedImage.image.image_name === image_name) { // if (stagedImage.image.image_name === image_name) {
stagedImage.image.image_url = image_url; // stagedImage.image.image_url = image_url;
stagedImage.image.thumbnail_url = thumbnail_url; // stagedImage.image.thumbnail_url = thumbnail_url;
} // }
}); // });
}); // });
}, },
}); });

View File

@ -38,7 +38,7 @@ export type CanvasImage = {
y: number; y: number;
width: number; width: number;
height: number; height: number;
image: ImageDTO; imageName: string;
}; };
export type CanvasMaskLine = { export type CanvasMaskLine = {

View File

@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnimatePresence, motion } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback'; import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa'; import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector( const selector = createSelector(
controlNetSelector, controlNetSelector,
@ -31,24 +33,45 @@ type Props = {
const ControlNetImagePreview = (props: Props) => { const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props; const { imageSx } = props;
const { controlNetId, controlImage, processedControlImage, processorType } = const {
props.controlNet; controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector); const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false); const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const {
data: controlImage,
isLoading: isLoadingControlImage,
isError: isErrorControlImage,
isSuccess: isSuccessControlImage,
} = useGetImageDTOQuery(controlImageName ?? skipToken);
const {
data: processedControlImage,
isLoading: isLoadingProcessedControlImage,
isError: isErrorProcessedControlImage,
isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback( const handleDrop = useCallback(
(droppedImage: ImageDTO) => { (droppedImage: ImageDTO) => {
if (controlImage?.image_name === droppedImage.image_name) { if (controlImageName === droppedImage.image_name) {
return; return;
} }
setIsMouseOverImage(false); setIsMouseOverImage(false);
dispatch( dispatch(
controlNetImageChanged({ controlNetId, controlImage: droppedImage }) controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
); );
}, },
[controlImage, controlNetId, dispatch] [controlImageName, controlNetId, dispatch]
); );
const handleResetControlImage = useCallback(() => { const handleResetControlImage = useCallback(() => {
@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
h: 'full', h: 'full',
}} }}
> >
<IAIImageFallback /> <IAIImageLoadingFallback />
</Box> </Box>
)} )}
{controlImage && ( {controlImage && (

View File

@ -39,8 +39,8 @@ export type ControlNetConfig = {
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlImage: ImageDTO | null; controlImage: string | null;
processedControlImage: ImageDTO | null; processedControlImage: string | null;
processorType: ControlNetProcessorType; processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode; processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean; shouldAutoConfig: boolean;
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
}, },
controlNetAddedFromImage: ( controlNetAddedFromImage: (
state, state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }> action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => { ) => {
const { controlNetId, controlImage } = action.payload; const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = { state.controlNets[controlNetId] = {
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
controlImage: ImageDTO | null; controlImage: string | null;
}> }>
) => { ) => {
const { controlNetId, controlImage } = action.payload; const { controlNetId, controlImage } = action.payload;
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
state, state,
action: PayloadAction<{ action: PayloadAction<{
controlNetId: string; controlNetId: string;
processedControlImage: ImageDTO | null; processedControlImage: string | null;
}> }>
) => { ) => {
const { controlNetId, processedControlImage } = action.payload; const { controlNetId, processedControlImage } = action.payload;
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
// Preemptively remove the image from the gallery // Preemptively remove the image from the gallery
const { imageName } = action.meta.arg; const { imageName } = action.meta.arg;
forEach(state.controlNets, (c) => { forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === imageName) { if (c.controlImage === imageName) {
c.controlImage = null; c.controlImage = null;
c.processedControlImage = null; c.processedControlImage = null;
} }
if (c.processedControlImage?.image_name === imageName) { if (c.processedControlImage === imageName) {
c.processedControlImage = null; c.processedControlImage = null;
} }
}); });
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload; // const { image_name, image_url, thumbnail_url } = action.payload;
forEach(state.controlNets, (c) => { // forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) { // if (c.controlImage?.image_name === image_name) {
c.controlImage.image_url = image_url; // c.controlImage.image_url = image_url;
c.controlImage.thumbnail_url = thumbnail_url; // c.controlImage.thumbnail_url = thumbnail_url;
} // }
if (c.processedControlImage?.image_name === image_name) { // if (c.processedControlImage?.image_name === image_name) {
c.processedControlImage.image_url = image_url; // c.processedControlImage.image_url = image_url;
c.processedControlImage.thumbnail_url = thumbnail_url; // c.processedControlImage.thumbnail_url = thumbnail_url;
} // }
}); // });
}); // });
builder.addCase(appSocketInvocationError, (state, action) => { builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = []; state.pendingControlImages = [];

View File

@ -0,0 +1,27 @@
import IAIButton from 'common/components/IAIButton';
import { useCallback } from 'react';
import { useCreateBoardMutation } from 'services/apiSlice';
const DEFAULT_BOARD_NAME = 'My Board';
const AddBoardButton = () => {
const [createBoard, { isLoading }] = useCreateBoardMutation();
const handleCreateBoard = useCallback(() => {
createBoard(DEFAULT_BOARD_NAME);
}, [createBoard]);
return (
<IAIButton
isLoading={isLoading}
aria-label="Add Board"
onClick={handleCreateBoard}
size="sm"
sx={{ px: 4 }}
>
Add Board
</IAIButton>
);
};
export default AddBoardButton;

View File

@ -0,0 +1,93 @@
import { Flex, Text } from '@chakra-ui/react';
import { FaImages } from 'react-icons/fa';
import { boardIdSelected } from '../../store/boardSlice';
import { useDispatch } from 'react-redux';
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
import { AnimatePresence } from 'framer-motion';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { useCallback } from 'react';
import { ImageDTO } from 'services/api';
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
import { useDroppable } from '@dnd-kit/core';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const handleAllImagesBoardClick = () => {
dispatch(boardIdSelected());
};
const [removeImageFromBoard, { isLoading }] =
useRemoveImageFromBoardMutation();
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (!droppedImage.board_id) {
return;
}
removeImageFromBoard({
board_id: droppedImage.board_id,
image_name: droppedImage.image_name,
});
},
[removeImageFromBoard]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_all_images`,
data: {
handleDrop,
},
});
return (
<Flex
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
onClick={handleAllImagesBoardClick}
>
<Flex
ref={setNodeRef}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
}}
>
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
<AnimatePresence>
{isSelected && <SelectedItemOverlay />}
</AnimatePresence>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex>
<Text
sx={{
color: isSelected ? 'base.50' : 'base.200',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
}}
>
All Images
</Text>
</Flex>
);
};
export default AllImagesBoard;

View File

@ -0,0 +1,134 @@
import {
Collapse,
Flex,
Grid,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
boardsSelector,
setBoardSearchText,
} from 'features/gallery/store/boardSlice';
import { memo, useState } from 'react';
import HoverableBoard from './HoverableBoard';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import AddBoardButton from './AddBoardButton';
import AllImagesBoard from './AllImagesBoard';
import { CloseIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/apiSlice';
const selector = createSelector(
[boardsSelector],
(boardsState) => {
const { selectedBoardId, searchText } = boardsState;
return { selectedBoardId, searchText };
},
defaultSelectorOptions
);
type Props = {
isOpen: boolean;
};
const BoardsList = (props: Props) => {
const { isOpen } = props;
const dispatch = useAppDispatch();
const { selectedBoardId, searchText } = useAppSelector(selector);
const { data: boards } = useListAllBoardsQuery();
const filteredBoards = searchText
? boards?.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase())
)
: boards;
const [searchMode, setSearchMode] = useState(false);
const handleBoardSearch = (searchTerm: string) => {
setSearchMode(searchTerm.length > 0);
dispatch(setBoardSearchText(searchTerm));
};
const clearBoardSearch = () => {
setSearchMode(false);
dispatch(setBoardSearchText(''));
};
return (
<Collapse in={isOpen} animateOpacity>
<Flex
sx={{
flexDir: 'column',
gap: 2,
bg: 'base.800',
borderRadius: 'base',
p: 2,
mt: 2,
}}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<InputGroup>
<Input
placeholder="Search Boards..."
value={searchText}
onChange={(e) => {
handleBoardSearch(e.target.value);
}}
/>
{searchText && searchText.length && (
<InputRightElement>
<IconButton
onClick={clearBoardSearch}
size="xs"
variant="ghost"
aria-label="Clear Search"
icon={<CloseIcon boxSize={3} />}
/>
</InputRightElement>
)}
</InputGroup>
<AddBoardButton />
</Flex>
<OverlayScrollbarsComponent
defer
style={{ height: '100%', width: '100%' }}
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
}}
>
<Grid
className="list-container"
sx={{
gap: 2,
gridTemplateRows: '5.5rem 5.5rem',
gridAutoFlow: 'column dense',
gridAutoColumns: '4rem',
}}
>
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
{filteredBoards &&
filteredBoards.map((board) => (
<HoverableBoard
key={board.board_id}
board={board}
isSelected={selectedBoardId === board.board_id}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</Flex>
</Collapse>
);
};
export default memo(BoardsList);

View File

@ -0,0 +1,193 @@
import {
Badge,
Box,
Editable,
EditableInput,
EditablePreview,
Flex,
Image,
MenuItem,
MenuList,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu';
import { BoardDTO, ImageDTO } from 'services/api';
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/boardSlice';
import {
useAddImageToBoardMutation,
useDeleteBoardMutation,
useGetImageDTOQuery,
useUpdateBoardMutation,
} from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useDroppable } from '@dnd-kit/core';
import { AnimatePresence } from 'framer-motion';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
interface HoverableBoardProps {
board: BoardDTO;
isSelected: boolean;
}
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
const dispatch = useAppDispatch();
const { data: coverImage } = useGetImageDTOQuery(
board.cover_image_name ?? skipToken
);
const { board_name, board_id } = board;
const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation();
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
useDeleteBoardMutation();
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
useAddImageToBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } });
};
const handleDeleteBoard = useCallback(() => {
deleteBoard(board_id);
}, [board_id, deleteBoard]);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (droppedImage.board_id === board_id) {
return;
}
addImageToBoard({ board_id, image_name: droppedImage.image_name });
},
[addImageToBoard, board_id]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_${board_id}`,
data: {
handleDrop,
},
});
return (
<Box sx={{ touchAction: 'none' }}>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
<MenuItem
sx={{ color: 'error.300' }}
icon={<FaTrash />}
onClickCapture={handleDeleteBoard}
>
Delete Board
</MenuItem>
</MenuList>
)}
>
{(ref) => (
<Flex
key={board_id}
userSelect="none"
ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
}}
>
<Flex
ref={setNodeRef}
onClick={handleSelectBoard}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
}}
>
{board.cover_image_name && coverImage?.image_url && (
<Image src={coverImage?.image_url} draggable={false} />
)}
{!(board.cover_image_name && coverImage?.image_url) && (
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
)}
<Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
<Badge variant="solid">{board.image_count}</Badge>
</Flex>
<AnimatePresence>
{isSelected && <SelectedItemOverlay />}
</AnimatePresence>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex>
<Box sx={{ width: 'full' }}>
<Editable
defaultValue={board_name}
submitOnBlur={false}
onSubmit={(nextValue) => {
handleUpdateBoardName(nextValue);
}}
>
<EditablePreview
sx={{
color: isSelected ? 'base.50' : 'base.200',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
textAlign: 'center',
p: 0,
}}
noOfLines={1}
/>
<EditableInput
sx={{
color: 'base.50',
fontSize: 'xs',
borderColor: 'base.500',
p: 0,
outline: 0,
}}
/>
</Editable>
</Box>
</Flex>
)}
</ContextMenu>
</Box>
);
});
HoverableBoard.displayName = 'HoverableBoard';
export default HoverableBoard;

View File

@ -0,0 +1,93 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Box,
Flex,
Spinner,
Text,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/apiSlice';
const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll);
const { data: boards, isFetching } = useListAllBoardsQuery();
const { isOpen, onClose, handleAddToBoard, image } = useContext(
AddImageToBoardContext
);
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const cancelRef = useRef<HTMLButtonElement>(null);
const currentBoard = boards?.find(
(board) => board.board_id === image?.board_id
);
return (
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
</AlertDialogHeader>
<AlertDialogBody>
<Box>
<Flex direction="column" gap={3}>
{currentBoard && (
<Text>
Moving this image from{' '}
<strong>{currentBoard.board_name}</strong> to
</Text>
)}
{isFetching ? (
<Spinner />
) : (
<IAIMantineSelect
placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={(boards ?? []).map((board) => ({
label: board.board_name,
value: board.board_id,
}))}
/>
)}
</Flex>
</Box>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton onClick={onClose}>Cancel</IAIButton>
<IAIButton
isDisabled={!selectedBoard}
colorScheme="accent"
onClick={() => {
if (selectedBoard) {
handleAddToBoard(selectedBoard);
}
}}
ml={3}
>
{currentBoard ? 'Move' : 'Add'}
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(UpdateImageBoardModal);

View File

@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { DeleteImageButton } from './DeleteImageModal'; import { DeleteImageButton } from './DeleteImageModal';
import { selectImagesById } from '../store/imagesSlice';
import { RootState } from 'app/store/store';
const currentImageButtonsSelector = createSelector( const currentImageButtonsSelector = createSelector(
[ [
(state: RootState) => state,
systemSelector, systemSelector,
gallerySelector, gallerySelector,
postprocessingSelector, postprocessingSelector,
@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
lightboxSelector, lightboxSelector,
activeTabNameSelector, activeTabNameSelector,
], ],
(system, gallery, postprocessing, ui, lightbox, activeTabName) => { (state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
const { const {
isProcessing, isProcessing,
isConnected, isConnected,
@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
shouldShowProgressInViewer, shouldShowProgressInViewer,
} = ui; } = ui;
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
const { selectedImage } = gallery; const { selectedImage } = gallery;
return { return {
@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
activeTabName, activeTabName,
isLightboxOpen, isLightboxOpen,
shouldHidePreview, shouldHidePreview,
image: selectedImage, image: imageDTO,
seed: selectedImage?.metadata?.seed, seed: imageDTO?.metadata?.seed,
prompt: selectedImage?.metadata?.positive_conditioning, prompt: imageDTO?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning, negativePrompt: imageDTO?.metadata?.negative_conditioning,
shouldShowProgressInViewer, shouldShowProgressInViewer,
}; };
}, },

View File

@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons'; import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { configSelector } from '../../system/store/configSelectors';
import { useAppToaster } from 'app/components/Toaster';
import { imageSelected } from '../store/gallerySlice'; import { imageSelected } from '../store/gallerySlice';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback'; import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector( export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector], [uiSelector, gallerySelector, systemSelector],
@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
return { return {
shouldShowImageDetails, shouldShowImageDetails,
shouldHidePreview, shouldHidePreview,
image: selectedImage, selectedImage,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
const CurrentImagePreview = () => { const CurrentImagePreview = () => {
const { const {
shouldShowImageDetails, shouldShowImageDetails,
image, selectedImage,
progressImage, progressImage,
shouldShowProgressInViewer, shouldShowProgressInViewer,
shouldAntialiasProgressImage, shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector); } = useAppSelector(imagesSelector);
// const image = useAppSelector((state: RootState) =>
// selectImagesById(state, selectedImage ?? '')
// );
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(selectedImage ?? skipToken);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const handleDrop = useCallback( const handleDrop = useCallback(
@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
if (droppedImage.image_name === image?.image_name) { if (droppedImage.image_name === image?.image_name) {
return; return;
} }
dispatch(imageSelected(droppedImage)); dispatch(imageSelected(droppedImage.image_name));
}, },
[dispatch, image?.image_name] [dispatch, image?.image_name]
); );
@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={image} image={selectedImage && image ? image : undefined}
onDrop={handleDrop} onDrop={handleDrop}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />} fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
isUploadDisabled={true} isUploadDisabled={true}
/> />
</Flex> </Flex>
)} )}
{shouldShowImageDetails && image && ( {shouldShowImageDetails && image && selectedImage && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',
@ -119,7 +131,7 @@ const CurrentImagePreview = () => {
<ImageMetadataViewer image={image} /> <ImageMetadataViewer image={image} />
</Box> </Box>
)} )}
{!shouldShowImageDetails && image && ( {!shouldShowImageDetails && image && selectedImage && (
<Box <Box
sx={{ sx={{
position: 'absolute', position: 'absolute',

View File

@ -2,7 +2,14 @@ import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useContext, useState } from 'react'; import { memo, useCallback, useContext, useState } from 'react';
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa'; import {
FaCheck,
FaExpand,
FaFolder,
FaImage,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu'; import { ContextMenu } from 'chakra-ui-contextmenu';
import { import {
resizeAndScaleCanvas, resizeAndScaleCanvas,
@ -27,6 +34,8 @@ import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { useDraggable } from '@dnd-kit/core'; import { useDraggable } from '@dnd-kit/core';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext'; import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
export const selector = createSelector( export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector], [gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
@ -62,17 +71,10 @@ interface HoverableImageProps {
isSelected: boolean; isSelected: boolean;
} }
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) =>
prev.image.image_name === next.image.image_name &&
prev.isSelected === next.isSelected;
/** /**
* Gallery image component with delete/use all/use seed buttons on hover. * Gallery image component with delete/use all/use seed buttons on hover.
*/ */
const HoverableImage = memo((props: HoverableImageProps) => { const HoverableImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { const {
activeTabName, activeTabName,
@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onDelete } = useContext(DeleteImageContext); const { onDelete } = useContext(DeleteImageContext);
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => { const handleDelete = useCallback(() => {
onDelete(image); onDelete(image);
}, [image, onDelete]); }, [image, onDelete]);
@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
}, },
}); });
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleMouseOver = () => setIsHovered(true); const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false); const handleMouseOut = () => setIsHovered(false);
const handleSelectImage = useCallback(() => { const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image)); dispatch(imageSelected(image.image_name));
}, [image, dispatch]); }, [image, dispatch]);
// Recall parameters handlers // Recall parameters handlers
@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// dispatch(setIsLightboxOpen(true)); // dispatch(setIsLightboxOpen(true));
}; };
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => { const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank'); window.open(image.image_url, '_blank');
}; };
@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')} {t('parameters.sendToUnifiedCanvas')}
</MenuItem> </MenuItem>
)} )}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem <MenuItem
sx={{ color: 'error.300' }} sx={{ color: 'error.300' }}
icon={<FaTrash />} icon={<FaTrash />}
@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</ContextMenu> </ContextMenu>
</Box> </Box>
); );
}, memoEqualityCheck); };
HoverableImage.displayName = 'HoverableImage'; export default memo(HoverableImage);
export default HoverableImage;

View File

@ -1,12 +1,15 @@
import { import {
Box, Box,
Button,
ButtonGroup, ButtonGroup,
Flex, Flex,
FlexProps, FlexProps,
Grid, Grid,
Icon, Icon,
Text, Text,
VStack,
forwardRef, forwardRef,
useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -20,6 +23,7 @@ import {
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn, setShouldUseSingleGalleryColumn,
setGalleryView,
} from 'features/gallery/store/gallerySlice'; } from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice'; import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
import { useOverlayScrollbars } from 'overlayscrollbars-react'; import { useOverlayScrollbars } from 'overlayscrollbars-react';
@ -53,41 +57,51 @@ import {
selectImagesAll, selectImagesAll,
} from '../store/imagesSlice'; } from '../store/imagesSlice';
import { receivedPageOfImages } from 'services/thunks/image'; import { receivedPageOfImages } from 'services/thunks/image';
import BoardsList from './Boards/BoardsList';
import { boardsSelector } from '../store/boardSlice';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/apiSlice';
const categorySelector = createSelector( const itemSelector = createSelector(
[(state: RootState) => state], [(state: RootState) => state],
(state) => { (state) => {
const { images } = state; const { categories, total: allImagesTotal, isLoading } = state.images;
const { categories } = images; const { selectedBoardId } = state.boards;
const allImages = selectImagesAll(state); const allImages = selectImagesAll(state);
const filteredImages = allImages.filter((i) =>
categories.includes(i.image_category) const images = allImages.filter((i) => {
); const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
return { return {
images: filteredImages, images,
isLoading: images.isLoading, allImagesTotal,
areMoreImagesAvailable: filteredImages.length < images.total, isLoading,
categories: images.categories, categories,
selectedBoardId,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
); );
const mainSelector = createSelector( const mainSelector = createSelector(
[gallerySelector, uiSelector], [gallerySelector, uiSelector, boardsSelector],
(gallery, ui) => { (gallery, ui, boards) => {
const { const {
galleryImageMinimumWidth, galleryImageMinimumWidth,
galleryImageObjectFit, galleryImageObjectFit,
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn, shouldUseSingleGalleryColumn,
selectedImage, selectedImage,
galleryView,
} = gallery; } = gallery;
const { shouldPinGallery } = ui; const { shouldPinGallery } = ui;
return { return {
shouldPinGallery, shouldPinGallery,
galleryImageMinimumWidth, galleryImageMinimumWidth,
@ -95,6 +109,8 @@ const mainSelector = createSelector(
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn, shouldUseSingleGalleryColumn,
selectedImage, selectedImage,
galleryView,
selectedBoardId: boards.selectedBoardId,
}; };
}, },
defaultSelectorOptions defaultSelectorOptions
@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
shouldAutoSwitchToNewImages, shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn, shouldUseSingleGalleryColumn,
selectedImage, selectedImage,
galleryView,
} = useAppSelector(mainSelector); } = useAppSelector(mainSelector);
const { images, areMoreImagesAvailable, isLoading, categories } = const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
useAppSelector(categorySelector); useAppSelector(itemSelector);
const { selectedBoard } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => ({
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
}),
});
const filteredImagesTotal = useMemo(
() => selectedBoard?.image_count ?? allImagesTotal,
[allImagesTotal, selectedBoard?.image_count]
);
const areMoreAvailable = useMemo(() => {
return images.length < filteredImagesTotal;
}, [images.length, filteredImagesTotal]);
const handleLoadMoreImages = useCallback(() => { const handleLoadMoreImages = useCallback(() => {
dispatch(receivedPageOfImages()); dispatch(
}, [dispatch]); receivedPageOfImages({
categories,
boardId: selectedBoardId,
})
);
}, [categories, dispatch, selectedBoardId]);
const handleEndReached = useMemo(() => { const handleEndReached = useMemo(() => {
if (areMoreImagesAvailable && !isLoading) { if (areMoreAvailable && !isLoading) {
return handleLoadMoreImages; return handleLoadMoreImages;
} }
return undefined; return undefined;
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]); }, [areMoreAvailable, handleLoadMoreImages, isLoading]);
const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
const handleChangeGalleryImageMinimumWidth = (v: number) => { const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v)); dispatch(setGalleryImageMinimumWidth(v));
@ -172,46 +211,79 @@ const ImageGalleryContent = () => {
const handleClickImagesCategory = useCallback(() => { const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images'));
}, [dispatch]); }, [dispatch]);
const handleClickAssetsCategory = useCallback(() => { const handleClickAssetsCategory = useCallback(() => {
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES)); dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
dispatch(setGalleryView('assets'));
}, [dispatch]); }, [dispatch]);
return ( return (
<Flex <VStack
sx={{ sx={{
gap: 2,
flexDirection: 'column', flexDirection: 'column',
h: 'full', h: 'full',
w: 'full', w: 'full',
borderRadius: 'base', borderRadius: 'base',
}} }}
> >
<Flex <Box sx={{ w: 'full' }}>
ref={resizeObserverRef} <Flex
alignItems="center" ref={resizeObserverRef}
justifyContent="space-between" sx={{
> alignItems: 'center',
<ButtonGroup isAttached> justifyContent: 'space-between',
<IAIIconButton gap: 2,
tooltip={t('gallery.images')} }}
aria-label={t('gallery.images')} >
onClick={handleClickImagesCategory} <ButtonGroup isAttached>
isChecked={categories === IMAGE_CATEGORIES} <IAIIconButton
tooltip={t('gallery.images')}
aria-label={t('gallery.images')}
onClick={handleClickImagesCategory}
isChecked={galleryView === 'images'}
size="sm"
icon={<FaImage />}
/>
<IAIIconButton
tooltip={t('gallery.assets')}
aria-label={t('gallery.assets')}
onClick={handleClickAssetsCategory}
isChecked={galleryView === 'assets'}
size="sm"
icon={<FaServer />}
/>
</ButtonGroup>
<Flex
as={Button}
onClick={onToggle}
size="sm" size="sm"
icon={<FaImage />} variant="ghost"
/> sx={{
<IAIIconButton w: 'full',
tooltip={t('gallery.assets')} justifyContent: 'center',
aria-label={t('gallery.assets')} alignItems: 'center',
onClick={handleClickAssetsCategory} px: 2,
isChecked={categories === ASSETS_CATEGORIES} _hover: {
size="sm" bg: 'base.800',
icon={<FaServer />} },
/> }}
</ButtonGroup> >
<Flex gap={2}> <Text
noOfLines={1}
sx={{ w: 'full', color: 'base.200', fontWeight: 600 }}
>
{selectedBoard ? selectedBoard.board_name : 'All Images'}
</Text>
<ChevronUpIcon
sx={{
transform: isBoardListOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
<IAIPopover <IAIPopover
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
@ -269,9 +341,12 @@ const ImageGalleryContent = () => {
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />} icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
/> />
</Flex> </Flex>
</Flex> <Box>
<Flex direction="column" gap={2} h="full"> <BoardsList isOpen={isBoardListOpen} />
{images.length || areMoreImagesAvailable ? ( </Box>
</Box>
<Flex direction="column" gap={2} h="full" w="full">
{images.length || areMoreAvailable ? (
<> <>
<Box ref={rootRef} data-overlayscrollbars="" h="100%"> <Box ref={rootRef} data-overlayscrollbars="" h="100%">
{shouldUseSingleGalleryColumn ? ( {shouldUseSingleGalleryColumn ? (
@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
data={images} data={images}
endReached={handleEndReached} endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)} scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => ( itemContent={(index, item) => (
<Flex sx={{ pb: 2 }}> <Flex sx={{ pb: 2 }}>
<HoverableImage <HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`} key={`${item.image_name}-${item.thumbnail_url}`}
image={image} image={item}
isSelected={ isSelected={selectedImage === item?.image_name}
selectedImage?.image_name === image?.image_name
}
/> />
</Flex> </Flex>
)} )}
@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
List: ListContainer, List: ListContainer,
}} }}
scrollerRef={setScroller} scrollerRef={setScroller}
itemContent={(index, image) => ( itemContent={(index, item) => (
<HoverableImage <HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`} key={`${item.image_name}-${item.thumbnail_url}`}
image={image} image={item}
isSelected={ isSelected={selectedImage === item?.image_name}
selectedImage?.image_name === image?.image_name
}
/> />
)} )}
/> />
@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
</Box> </Box>
<IAIButton <IAIButton
onClick={handleLoadMoreImages} onClick={handleLoadMoreImages}
isDisabled={!areMoreImagesAvailable} isDisabled={!areMoreAvailable}
isLoading={isLoading} isLoading={isLoading}
loadingText="Loading" loadingText="Loading"
flexShrink={0} flexShrink={0}
> >
{areMoreImagesAvailable {areMoreAvailable
? t('gallery.loadMore') ? t('gallery.loadMore')
: t('gallery.allImagesLoaded')} : t('gallery.allImagesLoaded')}
</IAIButton> </IAIButton>
@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
</Flex> </Flex>
)} )}
</Flex> </Flex>
</Flex> </VStack>
); );
}; };

View File

@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
image: ImageDTO; image: ImageDTO;
}; };
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.image_name === next.image.image_name;
// TODO: Show more interesting information in this component.
/** /**
* Image metadata viewer overlays currently selected image and provides * Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing. * access to any of its metadata for use in processing.
*/ */
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => { const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { const {
recallBothPrompts, recallBothPrompts,
@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Flex> </Flex>
</Flex> </Flex>
); );
}, memoEqualityCheck); };
ImageMetadataViewer.displayName = 'ImageMetadataViewer'; export default memo(ImageMetadataViewer);
export default ImageMetadataViewer;

View File

@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
} }
const currentImageIndex = filteredImageIds.findIndex( const currentImageIndex = filteredImageIds.findIndex(
(i) => i === selectedImage.image_name (i) => i === selectedImage
); );
const nextImageIndex = clamp( const nextImageIndex = clamp(
@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1, !isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
nextImage, nextImage,
prevImage, prevImage,
nextImageId,
prevImageId,
}; };
}, },
{ {
@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { isOnFirstImage, isOnLastImage, nextImage, prevImage } = const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
useAppSelector(nextPrevImageButtonsSelector); useAppSelector(nextPrevImageButtonsSelector);
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
}, []); }, []);
const handlePrevImage = useCallback(() => { const handlePrevImage = useCallback(() => {
dispatch(imageSelected(prevImage)); dispatch(imageSelected(prevImageId));
}, [dispatch, prevImage]); }, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => { const handleNextImage = useCallback(() => {
dispatch(imageSelected(nextImage)); dispatch(imageSelected(nextImageId));
}, [dispatch, nextImage]); }, [dispatch, nextImageId]);
useHotkeys( useHotkeys(
'left', 'left',
() => { () => {
handlePrevImage(); handlePrevImage();
}, },
[prevImage] [prevImageId]
); );
useHotkeys( useHotkeys(
@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
() => { () => {
handleNextImage(); handleNextImage();
}, },
[nextImage] [nextImageId]
); );
return ( return (

View File

@ -0,0 +1,26 @@
import { motion } from 'framer-motion';
export const SelectedItemOverlay = () => (
<motion.div
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
width: '100%',
height: '100%',
boxShadow: 'inset 0px 0px 0px 2px var(--invokeai-colors-accent-300)',
borderRadius: 'var(--invokeai-radii-base)',
}}
/>
);

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { selectBoardsAll } from './boardSlice';
export const boardSelector = (state: RootState) => state.boards.entities;
export const searchBoardsSelector = createSelector(
(state: RootState) => state,
(state) => {
const {
boards: { searchText },
} = state;
if (!searchText) {
// If no search text provided, return all entities
return selectBoardsAll(state);
}
return selectBoardsAll(state).filter((i) =>
i.board_name.toLowerCase().includes(searchText.toLowerCase())
);
}
);

View File

@ -0,0 +1,47 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { api } from 'services/apiSlice';
type BoardsState = {
searchText: string;
selectedBoardId?: string;
updateBoardModalOpen: boolean;
};
export const initialBoardsState: BoardsState = {
updateBoardModalOpen: false,
searchText: '',
};
const boardsSlice = createSlice({
name: 'boards',
initialState: initialBoardsState,
reducers: {
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedBoardId = action.payload;
},
setBoardSearchText: (state, action: PayloadAction<string>) => {
state.searchText = action.payload;
},
setUpdateBoardModalOpen: (state, action: PayloadAction<boolean>) => {
state.updateBoardModalOpen = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(
api.endpoints.deleteBoard.matchFulfilled,
(state, action) => {
if (action.meta.arg.originalArgs === state.selectedBoardId) {
state.selectedBoardId = undefined;
}
}
);
},
});
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
boardsSlice.actions;
export const boardsSelector = (state: RootState) => state.boards;
export default boardsSlice.reducer;

View File

@ -1,17 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api';
import { imageUpserted } from './imagesSlice'; import { imageUpserted } from './imagesSlice';
import { imageUrlsReceived } from 'services/thunks/image';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
export interface GalleryState { export interface GalleryState {
selectedImage?: ImageDTO; selectedImage?: string;
galleryImageMinimumWidth: number; galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType; galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean; shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean; shouldUseSingleGalleryColumn: boolean;
galleryView: 'images' | 'assets' | 'boards';
} }
export const initialGalleryState: GalleryState = { export const initialGalleryState: GalleryState = {
@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover', galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true, shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false, shouldUseSingleGalleryColumn: false,
galleryView: 'images',
}; };
export const gallerySlice = createSlice({ export const gallerySlice = createSlice({
name: 'gallery', name: 'gallery',
initialState: initialGalleryState, initialState: initialGalleryState,
reducers: { reducers: {
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => { imageSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedImage = action.payload; state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch? // TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false; // state.shouldAutoSwitchToNewImages = false;
@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
) => { ) => {
state.shouldUseSingleGalleryColumn = action.payload; state.shouldUseSingleGalleryColumn = action.payload;
}, },
setGalleryView: (
state,
action: PayloadAction<'images' | 'assets' | 'boards'>
) => {
state.galleryView = action.payload;
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(imageUpserted, (state, action) => { builder.addCase(imageUpserted, (state, action) => {
@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
state.shouldAutoSwitchToNewImages && state.shouldAutoSwitchToNewImages &&
action.payload.image_category === 'general' action.payload.image_category === 'general'
) { ) {
state.selectedImage = action.payload; state.selectedImage = action.payload.image_name;
} }
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { // builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload; // const { image_name, image_url, thumbnail_url } = action.payload;
if (state.selectedImage?.image_name === image_name) { // if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url; // state.selectedImage.image_url = image_url;
state.selectedImage.thumbnail_url = thumbnail_url; // state.selectedImage.thumbnail_url = thumbnail_url;
} // }
}); // });
}, },
}); });
@ -75,6 +81,7 @@ export const {
setGalleryImageObjectFit, setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages, setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn, setShouldUseSingleGalleryColumn,
setGalleryView,
} = gallerySlice.actions; } = gallerySlice.actions;
export default gallerySlice.reducer; export default gallerySlice.reducer;

View File

@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
import { keyBy } from 'lodash-es'; import { keyBy } from 'lodash-es';
import { import {
imageDeleted, imageDeleted,
imageMetadataReceived,
imageUrlsReceived, imageUrlsReceived,
receivedPageOfImages, receivedPageOfImages,
} from 'services/thunks/image'; } from 'services/thunks/image';
@ -74,11 +73,21 @@ const imagesSlice = createSlice({
}); });
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => { builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false; state.isLoading = false;
const { boardId, categories, imageOrigin, isIntermediate } =
action.meta.arg;
const { items, offset, limit, total } = action.payload; const { items, offset, limit, total } = action.payload;
imagesAdapter.upsertMany(state, items);
if (!categories?.includes('general') || boardId) {
// need to skip updating the total images count if the images recieved were for a specific board
// TODO: this doesn't work when on the Asset tab/category...
return;
}
state.offset = offset; state.offset = offset;
state.limit = limit; state.limit = limit;
state.total = total; state.total = total;
imagesAdapter.upsertMany(state, items);
}); });
builder.addCase(imageDeleted.pending, (state, action) => { builder.addCase(imageDeleted.pending, (state, action) => {
// Image deleted // Image deleted
@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
.map((i) => i.image_name); .map((i) => i.image_name);
} }
); );
// export const selectImageById = createSelector(
// (state: RootState, imageId) => state,
// (state) => {
// const {
// images: { categories },
// } = state;
// return selectImagesAll(state)
// .filter((i) => categories.includes(i.image_category))
// .map((i) => i.image_name);
// }
// );

View File

@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate> props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(field.value ?? skipToken);
const handleDrop = useCallback( const handleDrop = useCallback(
(droppedImage: ImageDTO) => { (droppedImage: ImageDTO) => {
if (field.value?.image_name === droppedImage.image_name) { if (field.value === droppedImage.image_name) {
return; return;
} }
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
fieldValueChanged({ fieldValueChanged({
nodeId, nodeId,
fieldName: field.name, fieldName: field.name,
value: droppedImage, value: droppedImage.image_name,
}) })
); );
}, },
[dispatch, field.name, field.value?.image_name, nodeId] [dispatch, field.name, field.value, nodeId]
); );
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
}} }}
> >
<IAIDndImage <IAIDndImage
image={field.value} image={image}
onDrop={handleDrop} onDrop={handleDrop}
onReset={handleReset} onReset={handleReset}
resetIconSize="sm" resetIconSize="sm"

View File

@ -1,28 +1,18 @@
import { Select } from '@chakra-ui/react'; import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
ModelInputFieldTemplate, ModelInputFieldTemplate,
ModelInputFieldValue, ModelInputFieldValue,
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { selectModelsIds } from 'features/system/store/modelSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector( import { memo, useCallback, useEffect, useMemo } from 'react';
[selectModelsIds], import { FieldComponentProps } from './types';
(allModelNames) => { import { forEach, isString } from 'lodash-es';
return { allModelNames }; import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
// return map(modelList, (_, name) => name); import IAIMantineSelect from 'common/components/IAIMantineSelect';
}, import { useTranslation } from 'react-i18next';
{ import { useListModelsQuery } from 'services/apiSlice';
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ModelInputFieldComponent = ( const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate> props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
const { nodeId, field } = props; const { nodeId, field } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation();
const { allModelNames } = useAppSelector(availableModelsSelector); const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => { const data = useMemo(() => {
dispatch( if (!pipelineModels) {
fieldValueChanged({ return [];
nodeId, }
fieldName: field.name,
value: e.target.value, const data: SelectItem[] = [];
})
); forEach(pipelineModels.entities, (model, id) => {
}; if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
return ( return (
<Select <IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged} onChange={handleValueChanged}
value={field.value || allModelNames[0]} />
>
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
); );
}; };

View File

@ -101,21 +101,6 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image') {
if (input.value?.image_name === image_name) {
input.value.image_url = image_url;
input.value.thumbnail_url = thumbnail_url;
}
}
});
});
});
}, },
}); });

View File

@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
export type ImageInputFieldValue = FieldValueBase & { export type ImageInputFieldValue = FieldValueBase & {
type: 'image'; type: 'image';
value?: ImageDTO; value?: string;
}; };
export type ModelInputFieldValue = FieldValueBase & { export type ModelInputFieldValue = FieldValueBase & {

View File

@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') { if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image // We've already processed the image in the app, so we can just use the processed image
const { image_name } = processedControlImage;
controlNetNode.image = { controlNetNode.image = {
image_name, image_name: processedControlImage,
}; };
} else if (controlImage) { } else if (controlImage) {
// The control image is preprocessed // The control image is preprocessed
const { image_name } = controlImage;
controlNetNode.image = { controlNetNode.image = {
image_name, image_name: controlImage,
}; };
} else { } else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly // Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@ -23,6 +23,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -17,6 +17,7 @@ import {
INPAINT_GRAPH, INPAINT_GRAPH,
INPAINT, INPAINT,
} from './constants'; } from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image // We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas; const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: INPAINT_GRAPH, id: INPAINT_GRAPH,
nodes: { nodes: {
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
type: 'range_of_size', type: 'range_of_size',

View File

@ -14,6 +14,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params // The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions; const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -22,6 +22,7 @@ import {
} from './constants'; } from './constants';
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state'); throw new Error('No initial image found in state');
} }
const model = modelIdToPipelineModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids // copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = { const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH, id: IMAGE_TO_IMAGE_GRAPH,
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = (
id: RESIZE, id: RESIZE,
type: 'img_resize', type: 'img_resize',
image: { image: {
image_name: initialImage.image_name, image_name: initialImage.imageName,
}, },
is_intermediate: true, is_intermediate: true,
width, width,
@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = (
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name, image_name: initialImage.imageName,
}); });
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node

View File

@ -1,6 +1,10 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api'; import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
@ -14,6 +18,7 @@ import {
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = { type TextToImageGraphOverrides = {
width: number; width: number;
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
const { const {
positivePrompt, positivePrompt,
negativePrompt, negativePrompt,
model: model_name, model: modelId,
cfgScale: cfg_scale, cfgScale: cfg_scale,
scheduler, scheduler,
steps, steps,
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
shouldRandomizeSeed, shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId);
/** /**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node * full graph here as a template. Then use the parameters from app state and set friendlier node
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
steps, steps,
}, },
[MODEL_LOADER]: { [MODEL_LOADER]: {
type: 'sd1_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: MODEL_LOADER,
model_name, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',

View File

@ -1,9 +1,10 @@
import { Graph } from 'services/api'; import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es'; import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types'; import { InputFieldValue } from 'features/nodes/types/types';
import { AnyInvocation } from 'services/events/types'; import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/** /**
* We need to do special handling for some fields * We need to do special handling for some fields
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
} }
} }
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
}
}
return field.value; return field.value;
}; };

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'model_loader'; export const MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';

View File

@ -0,0 +1,18 @@
import { BaseModelType, PipelineModelField } from 'services/api';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
} }
imageToImageNode.image = { imageToImageNode.image = {
image_name: initialImage.image_name, image_name: initialImage.imageName,
}; };
} }

View File

@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => { const ParamSchedulerAndModel = () => {
return ( return (
<Flex gap={3} w="full"> <Flex gap={3} w="full">
<Box w="20rem"> <Box w="25rem">
<ParamScheduler /> <ParamScheduler />
</Box> </Box>
<Box w="full"> <Box w="full">

View File

@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage'; import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback'; import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector( const selector = createSelector(
[generationSelector], [generationSelector],
@ -27,14 +29,21 @@ const InitialImagePreview = () => {
const { initialImage } = useAppSelector(selector); const { initialImage } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
const handleDrop = useCallback( const handleDrop = useCallback(
(droppedImage: ImageDTO) => { (droppedImage: ImageDTO) => {
if (droppedImage.image_name === initialImage?.image_name) { if (droppedImage.image_name === initialImage?.imageName) {
return; return;
} }
dispatch(initialImageChanged(droppedImage)); dispatch(initialImageChanged(droppedImage));
}, },
[dispatch, initialImage?.image_name] [dispatch, initialImage]
); );
const handleReset = useCallback(() => { const handleReset = useCallback(() => {
@ -53,10 +62,10 @@ const InitialImagePreview = () => {
}} }}
> >
<IAIDndImage <IAIDndImage
image={initialImage} image={image}
onDrop={handleDrop} onDrop={handleDrop}
onReset={handleReset} onReset={handleReset}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />} fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }} postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
withResetIcon withResetIcon
/> />

View File

@ -1,10 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
import { configChanged } from 'features/system/store/configSlice'; import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es'; import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { import {
CfgScaleParam, CfgScaleParam,
HeightParam, HeightParam,
@ -17,14 +16,13 @@ import {
StrengthParam, StrengthParam,
WidthParam, WidthParam,
} from './parameterZodSchemas'; } from './parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState { export interface GenerationState {
cfgScale: CfgScaleParam; cfgScale: CfgScaleParam;
height: HeightParam; height: HeightParam;
img2imgStrength: StrengthParam; img2imgStrength: StrengthParam;
infillMethod: string; infillMethod: string;
initialImage?: ImageDTO; initialImage?: { imageName: string; width: number; height: number };
iterations: number; iterations: number;
perlin: number; perlin: number;
positivePrompt: PositivePromptParam; positivePrompt: PositivePromptParam;
@ -212,35 +210,20 @@ export const generationSlice = createSlice({
state.shouldUseNoiseSettings = action.payload; state.shouldUseNoiseSettings = action.payload;
}, },
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => { initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
state.initialImage = action.payload; const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
}, },
modelSelected: (state, action: PayloadAction<string>) => { modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload; state.model = action.payload;
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
const firstModel = sortBy(action.payload, 'name')[0];
state.model = firstModel.name;
}
});
builder.addCase(configChanged, (state, action) => { builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel; const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) { if (defaultModel && !state.model) {
state.model = defaultModel; state.model = defaultModel;
} }
}); });
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;
state.initialImage.thumbnail_url = thumbnail_url;
}
});
}, },
}); });

View File

@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
*/ */
export const isValidStrength = (val: unknown): val is StrengthParam => export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success; zStrength.safeParse(val).success;
// /**
// * Zod schema for BaseModelType
// */
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
// /**
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
// */
// export type BaseModelType = z.infer<typeof zBaseModelType>;
// /**
// * Validates/type-guards a value as a base model type
// */
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
// zBaseModelType.safeParse(val).success;

View File

@ -1,44 +1,59 @@
import { createSelector } from '@reduxjs/toolkit'; import { memo, useCallback, useEffect, useMemo } from 'react';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import IAIMantineSelect from 'common/components/IAIMantineSelect';
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
const selector = createSelector( import { forEach, isString } from 'lodash-es';
[(state: RootState) => state, generationSelector], import { SelectItem } from '@mantine/core';
(state, generation) => { import { RootState } from 'app/store/store';
const selectedModel = selectModelsById(state, generation.model); import { useListModelsQuery } from 'services/apiSlice';
const modelData = selectModelsAll(state) export const MODEL_TYPE_MAP = {
.map<IAISelectDataType>((m) => ({ 'sd-1': 'Stable Diffusion 1.x',
value: m.name, 'sd-2': 'Stable Diffusion 2.x',
label: m.name, };
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.model
);
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const data = useMemo(() => {
if (!pipelineModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null) => { (v: string | null) => {
if (!v) { if (!v) {
@ -49,13 +64,27 @@ const ModelSelect = () => {
[dispatch] [dispatch]
); );
useEffect(() => {
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
return ( return (
<IAIMantineSelect <IAIMantineSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.name ?? ''} value={selectedModelId}
placeholder="Pick one" placeholder="Pick one"
data={modelData} data={data}
onChange={handleChangeModel} onChange={handleChangeModel}
/> />
); );

View File

@ -1,6 +1,5 @@
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
export default function SettingsSchedulers() { export default function SettingsSchedulers() {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const enabledSchedulers = useAppSelector( const enabledSchedulers = useAppSelector(

View File

@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
const isApplicationReadySelector = createSelector( const isApplicationReadySelector = createSelector(
[systemSelector, configSelector], [systemSelector, configSelector],
(system, config) => { (system, config) => {
const { wereModelsReceived, wasSchemaParsed } = system; const { wasSchemaParsed } = system;
const { disabledTabs } = config; const { disabledTabs } = config;
return { return {
disabledTabs, disabledTabs,
wereModelsReceived,
wasSchemaParsed, wasSchemaParsed,
}; };
} }
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
* Checks if the application is ready to be used, i.e. if the initial startup process is finished. * Checks if the application is ready to be used, i.e. if the initial startup process is finished.
*/ */
export const useIsApplicationReady = () => { export const useIsApplicationReady = () => {
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector( const { disabledTabs, wasSchemaParsed } = useAppSelector(
isApplicationReadySelector isApplicationReadySelector
); );
const isApplicationReady = useMemo(() => { const isApplicationReady = useMemo(() => {
if (!wereModelsReceived) {
return false;
}
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) { if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
return false; return false;
} }
return true; return true;
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]); }, [disabledTabs, wasSchemaParsed]);
return isApplicationReady; return isApplicationReady;
}; };

View File

@ -1,3 +0,0 @@
import { RootState } from 'app/store/store';
export const modelSelector = (state: RootState) => state.models;

View File

@ -1,47 +0,0 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
name: string;
};
export const modelsAdapter = createEntityAdapter<Model>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const initialModelsState = modelsAdapter.getInitialState();
export type ModelsState = typeof initialModelsState;
export const modelsSlice = createSlice({
name: 'models',
initialState: initialModelsState,
reducers: {
modelAdded: modelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload;
modelsAdapter.setAll(state, models);
});
},
});
export const {
selectAll: selectModelsAll,
selectById: selectModelsById,
selectEntities: selectModelsEntities,
selectIds: selectModelsIds,
selectTotal: selectModelsTotal,
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
export const { modelAdded } = modelsSlice.actions;
export default modelsSlice.reducer;

View File

@ -1,6 +0,0 @@
import { ModelsState } from './modelSlice';
/**
* Models slice persist denylist
*/
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];

View File

@ -1,20 +1,12 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction } from '@reduxjs/toolkit'; import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai'; import * as InvokeAI from 'app/types/invokeai';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { TFuncKey } from 'i18next';
import { t } from 'i18next';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker'; import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { imageUploaded } from 'services/thunks/image'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr';
import { import {
appSocketConnected, appSocketConnected,
appSocketDisconnected, appSocketDisconnected,
@ -26,6 +18,11 @@ import {
appSocketSubscribed, appSocketSubscribed,
appSocketUnsubscribed, appSocketUnsubscribed,
} from 'services/events/actions'; } from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/thunks/image';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -95,6 +92,7 @@ export interface SystemState {
shouldAntialiasProgressImage: boolean; shouldAntialiasProgressImage: boolean;
language: keyof typeof LANGUAGES; language: keyof typeof LANGUAGES;
isUploading: boolean; isUploading: boolean;
boardIdToAddTo?: string;
} }
export const initialSystemState: SystemState = { export const initialSystemState: SystemState = {
@ -225,6 +223,7 @@ export const systemSlice = createSlice({
*/ */
builder.addCase(appSocketSubscribed, (state, action) => { builder.addCase(appSocketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId; state.sessionId = action.payload.sessionId;
state.boardIdToAddTo = action.payload.boardId;
state.canceledSession = ''; state.canceledSession = '';
}); });
@ -233,6 +232,7 @@ export const systemSlice = createSlice({
*/ */
builder.addCase(appSocketUnsubscribed, (state) => { builder.addCase(appSocketUnsubscribed, (state) => {
state.sessionId = null; state.sessionId = null;
state.boardIdToAddTo = undefined;
}); });
/** /**
@ -376,13 +376,6 @@ export const systemSlice = createSlice({
); );
}); });
/**
* Received available models from the backend
*/
builder.addCase(receivedModels.fulfilled, (state) => {
state.wereModelsReceived = true;
});
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */

View File

@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation'; export type { AddInvocation } from './models/AddInvocation';
export type { BaseModelType } from './models/BaseModelType'; export type { BaseModelType } from './models/BaseModelType';
export type { BoardChanges } from './models/BoardChanges';
export type { BoardDTO } from './models/BoardDTO';
export type { Body_create_board_image } from './models/Body_create_board_image';
export type { Body_remove_board_image } from './models/Body_remove_board_image';
export type { Body_upload_image } from './models/Body_upload_image'; export type { Body_upload_image } from './models/Body_upload_image';
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation'; export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo'; export type { CkptModelInfo } from './models/CkptModelInfo';
@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation'; export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField'; export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation'; export type { ControlNetInvocation } from './models/ControlNetInvocation';
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
export type { ControlOutput } from './models/ControlOutput'; export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest'; export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation'; export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation'; export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput'; export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput'; export type { IntOutput } from './models/IntOutput';
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
export type { IterateInvocation } from './models/IterateInvocation'; export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput'; export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField'; export type { LatentsField } from './models/LatentsField';
@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { LoraInfo } from './models/LoraInfo'; export type { LoraInfo } from './models/LoraInfo';
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation'; export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
export type { LoraLoaderOutput } from './models/LoraLoaderOutput'; export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
export type { LoRAModelConfig } from './models/LoRAModelConfig';
export type { LoRAModelFormat } from './models/LoRAModelFormat';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation'; export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput'; export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation'; export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation'; export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput'; export type { NoiseOutput } from './models/NoiseOutput';
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation'; export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_'; export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation'; export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_'; export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { ParamFloatInvocation } from './models/ParamFloatInvocation'; export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation'; export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation'; export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PipelineModelField } from './models/PipelineModelField';
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput'; export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput'; export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation'; export type { RandomIntInvocation } from './models/RandomIntInvocation';
@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation'; export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation'; export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { SchedulerPredictionType } from './models/SchedulerPredictionType'; export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation'; export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation'; export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubModelType } from './models/SubModelType'; export type { SubModelType } from './models/SubModelType';
export type { SubtractInvocation } from './models/SubtractInvocation'; export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation'; export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
export type { UNetField } from './models/UNetField'; export type { UNetField } from './models/UNetField';
export type { UpscaleInvocation } from './models/UpscaleInvocation'; export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeField } from './models/VaeField'; export type { VaeField } from './models/VaeField';
export type { VaeModelConfig } from './models/VaeModelConfig';
export type { VaeModelFormat } from './models/VaeModelFormat';
export type { VaeRepo } from './models/VaeRepo'; export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError'; export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation'; export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
export { BoardsService } from './services/BoardsService';
export { ImagesService } from './services/ImagesService'; export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService'; export { ModelsService } from './services/ModelsService';
export { SessionsService } from './services/SessionsService'; export { SessionsService } from './services/SessionsService';

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type BoardChanges = {
/**
* The board's new name.
*/
board_name?: string;
/**
* The name of the board's new cover image.
*/
cover_image_name?: string;
};

View File

@ -0,0 +1,38 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Deserialized board record with cover image URL and image count.
*/
export type BoardDTO = {
/**
* The unique ID of the board.
*/
board_id: string;
/**
* The name of the board.
*/
board_name: string;
/**
* The created timestamp of the board.
*/
created_at: string;
/**
* The updated timestamp of the board.
*/
updated_at: string;
/**
* The deleted timestamp of the board.
*/
deleted_at?: string;
/**
* The name of the board's cover image.
*/
cover_image_name?: string;
/**
* The number of images in the board.
*/
image_count: number;
};

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_create_board_image = {
/**
* The id of the board to add to
*/
board_id: string;
/**
* The name of the image to add
*/
image_name: string;
};

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_remove_board_image = {
/**
* The id of the board
*/
board_id: string;
/**
* The name of the image to remove
*/
image_name: string;
};

View File

@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ControlNetModelFormat } from './ControlNetModelFormat';
import type { ModelError } from './ModelError';
export type ControlNetModelConfig = {
name: string;
base_model: BaseModelType;
type: 'controlnet';
path: string;
description?: string;
model_format: ControlNetModelFormat;
error?: ModelError;
};

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';

Some files were not shown because too many files have changed in this diff Show More