Merge branch 'main' into install-script-python-version-error-prompt-fix

This commit is contained in:
Lincoln Stein 2023-06-23 02:15:36 +01:00 committed by GitHub
commit df1907e849
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
137 changed files with 4087 additions and 1018 deletions

View File

@ -2,8 +2,17 @@
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
BoardImagesService,
BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
@ -57,7 +66,7 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_location = config.db_path
db_location.parent.mkdir(parents=True,exist_ok=True)
db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@ -72,14 +81,40 @@ class ApiDependencies:
DiskLatentsStorage(f"{output_folder}/latents")
)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
boards = BoardService(
services=BoardServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
board_images = BoardImagesService(
services=BoardImagesServiceDependencies(
board_image_record_storage=board_image_record_storage,
board_record_storage=board_record_storage,
image_record_storage=image_record_storage,
url=urls,
logger=logger,
)
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
services=ImageServiceDependencies(
board_image_record_storage=board_image_record_storage,
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(
@ -87,6 +122,8 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -0,0 +1,69 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
@board_images_router.post(
"/",
operation_id="create_board_image",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def create_board_image(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
"""Creates a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add to board")
@board_images_router.delete(
"/",
operation_id="remove_board_image",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
):
"""Deletes a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@board_images_router.get(
"/{board_id}",
operation_id="list_board_images",
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_board_images(
board_id: str = Path(description="The id of the board"),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of boards per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images for a board"""
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
board_id,
)
return results

View File

@ -0,0 +1,108 @@
from typing import Optional, Union
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from ..dependencies import ApiDependencies
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@boards_router.post(
"/",
operation_id="create_board",
responses={
201: {"description": "The board was created successfully"},
},
status_code=201,
response_model=BoardDTO,
)
async def create_board(
board_name: str = Query(description="The name of the board to create"),
) -> BoardDTO:
"""Creates a board"""
try:
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create board")
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
async def get_board(
board_id: str = Path(description="The id of board to get"),
) -> BoardDTO:
"""Gets a board"""
try:
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
return result
except Exception as e:
raise HTTPException(status_code=404, detail="Board not found")
@boards_router.patch(
"/{board_id}",
operation_id="update_board",
responses={
201: {
"description": "The board was updated successfully",
},
},
status_code=201,
response_model=BoardDTO,
)
async def update_board(
board_id: str = Path(description="The id of board to update"),
changes: BoardChanges = Body(description="The changes to apply to the board"),
) -> BoardDTO:
"""Updates a board"""
try:
result = ApiDependencies.invoker.services.boards.update(
board_id=board_id, changes=changes
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete("/{board_id}", operation_id="delete_board")
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
) -> None:
"""Deletes a board"""
try:
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@boards_router.get(
"/",
operation_id="list_boards",
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
)

View File

@ -221,6 +221,9 @@ async def list_images_with_metadata(
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
board_id: Optional[str] = Query(
default=None, description="The board id to filter by"
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
@ -232,6 +235,7 @@ async def list_images_with_metadata(
image_origin,
categories,
is_intermediate,
board_id,
)
return image_dtos

View File

@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import get_all_model_configs
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
models: list[MODEL_CONFIGS]
@models_router.get(
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: BaseModelType = Query(
base_model: Optional[BaseModelType] = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:

View File

@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir
from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images
from .api.routers import sessions, models, images, boards, board_images
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi():
@ -116,6 +120,22 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
if name in openapi_schema["components"]["schemas"]:
# print(f"Config with name {name} already defined")
continue
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
openapi_schema["components"]["schemas"][name] = dict(
title=name,
description="An enumeration.",
type="string",
enum=list(v.value for v in model_config_format_enum),
)
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on
class SD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
class PipelineModelField(BaseModel):
"""Pipeline model field"""
type: Literal["sd1_model_loader"] = "sd1_model_loader"
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_name: str = Field(default="", description="Model to load")
class PipelineModelLoaderInvocation(BaseInvocation):
"""Loads a pipeline model, outputting its submodels."""
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
"model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion1 # TODO:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
):
raise Exception(f"Unkown model name: {self.model_name}!")
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
model_name=model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae,
),
)
)
# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd2_model_loader"] = "sd2_model_loader"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
model_type=model_type,
submodel=SubModelType.Vae,
),
)

View File

@ -0,0 +1,254 @@
from abc import ABC, abstractmethod
import sqlite3
import threading
from typing import Union, cast
from invokeai.app.services.board_record_storage import BoardRecord
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
class BoardImageRecordStorageBase(ABC):
"""Abstract base class for the one-to-many board-image relationship record storage."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
@abstractmethod
def get_image_count_for_board(
self,
board_id: str,
) -> int:
"""Gets the number of images for a board."""
pass
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `board_images` junction table."""
# Create the `board_images` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS board_images (
board_id TEXT NOT NULL,
image_name TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between boards and images using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (image_name),
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
);
"""
)
# Add index for board id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
"""
)
# Add index for board id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
AFTER UPDATE
ON board_images FOR EACH ROW
BEGIN
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE board_id = old.board_id AND image_name = old.image_name;
END;
"""
)
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
""",
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
# TODO: this isn't paginated yet?
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_image_count_for_board(self, board_id: str) -> int:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
""",
(board_id,),
)
count = cast(int, self._cursor.fetchone()[0])
return count
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,142 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import List, Union
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_record_storage import (
BoardRecord,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
from invokeai.app.services.urls import UrlServiceBase
class BoardImagesServiceABC(ABC):
"""High-level service for board-image relationship management."""
@abstractmethod
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Adds an image to a board."""
pass
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
pass
@abstractmethod
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images for a board."""
pass
@abstractmethod
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
"""Gets an image's board id, if it has one."""
pass
class BoardImagesServiceDependencies:
"""Service dependencies for the BoardImagesService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardImagesService(BoardImagesServiceABC):
_services: BoardImagesServiceDependencies
def __init__(self, services: BoardImagesServiceDependencies):
self._services = services
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.add_image_to_board(board_id, image_name)
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_images_for_board(
self,
board_id: str,
) -> OffsetPaginatedResults[ImageDTO]:
image_records = self._services.board_image_records.get_images_for_board(
board_id
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
board_id,
),
image_records.items,
)
)
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=image_records.offset,
limit=image_records.limit,
total=image_records.total,
)
def get_board_for_image(
self,
image_name: str,
) -> Union[str, None]:
board_id = self._services.board_image_records.get_board_for_image(image_name)
return board_id
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: str | None, image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(exclude={'cover_image_name'}),
cover_image_name=cover_image_name,
image_count=image_count,
)

View File

@ -0,0 +1,329 @@
from abc import ABC, abstractmethod
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
import uuid
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import (
BoardRecord,
deserialize_board_record,
)
from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
super().__init__(message)
class BoardRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the board record store."""
@abstractmethod
def delete(self, board_id: str) -> None:
"""Deletes a board record."""
pass
@abstractmethod
def save(
self,
board_name: str,
) -> BoardRecord:
"""Saves a board record."""
pass
@abstractmethod
def get(
self,
board_id: str,
) -> BoardRecord:
"""Gets a board record."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
"""Updates a board record."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `boards` table and `board_images` junction table."""
# Create the `boards` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS boards (
board_id TEXT NOT NULL PRIMARY KEY,
board_name TEXT NOT NULL,
cover_image_name TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
AFTER UPDATE
ON boards FOR EACH ROW
BEGIN
UPDATE boards SET updated_at = current_timestamp
WHERE board_id = old.board_id;
END;
"""
)
def delete(self, board_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
board_name: str,
) -> BoardRecord:
try:
board_id = str(uuid.uuid4())
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
try:
self._lock.acquire()
# Change the name of a board
if changes.board_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
self._cursor.execute(
f"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
finally:
self._lock.release()
return self.get(board_id)
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
ORDER BY created_at DESC
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
return boards
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()

View File

@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from logging import Logger
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.board_images import board_record_to_dto
from invokeai.app.services.board_record_storage import (
BoardChanges,
BoardRecordStorageBase,
)
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.urls import UrlServiceBase
class BoardServiceABC(ABC):
"""High-level service for board management."""
@abstractmethod
def create(
self,
board_name: str,
) -> BoardDTO:
"""Creates a board."""
pass
@abstractmethod
def get_dto(
self,
board_id: str,
) -> BoardDTO:
"""Gets a board."""
pass
@abstractmethod
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
"""Updates a board."""
pass
@abstractmethod
def delete(
self,
board_id: str,
) -> None:
"""Deletes a board."""
pass
@abstractmethod
def get_many(
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass
class BoardServiceDependencies:
"""Service dependencies for the BoardService."""
board_image_records: BoardImageRecordStorageBase
board_records: BoardRecordStorageBase
image_records: ImageRecordStorageBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
board_image_record_storage: BoardImageRecordStorageBase,
image_record_storage: ImageRecordStorageBase,
board_record_storage: BoardRecordStorageBase,
url: UrlServiceBase,
logger: Logger,
):
self.board_image_records = board_image_record_storage
self.image_records = image_record_storage
self.board_records = board_record_storage
self.urls = url
self.logger = logger
class BoardService(BoardServiceABC):
_services: BoardServiceDependencies
def __init__(self, services: BoardServiceDependencies):
self._services = services
def create(
self,
board_name: str,
) -> BoardDTO:
board_record = self._services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
self,
board_id: str,
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos

View File

@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
"""Saves an image record."""
pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
"""Gets the most recent image for a board."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
"""Creates the `images` table."""
# Create the `images` table.
self._cursor.execute(
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
board_id TEXT,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE image_name = old.image_name;
END;
"""
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
images_query = """--sql
SELECT images.*
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params = []
if image_origin is not None:
query_conditions += f"""AND image_origin = ?\n"""
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
## Convert the enum values to unique list of strings
# Convert the enum values to unique list of strings
category_strings = list(map(lambda c: c.value, set(categories)))
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
if board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
query_pagination = """--sql
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0]
count = cast(int, self._cursor.fetchone()[0])
except sqlite3.Error as e:
self._conn.rollback()
raise e
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_most_recent_image_for_board(
self, board_id: str
) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
finally:
self._lock.release()
if result is None:
return None
return deserialize_image_record(dict(result))

View File

@ -10,6 +10,7 @@ from invokeai.app.models.image import (
InvalidOriginException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
intermediate: bool = False,
is_intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_path(self, image_name: str) -> str:
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""
pass
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
image_records: ImageRecordStorageBase
image_files: ImageFileStorageBase
board_image_records: BoardImageRecordStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
@ -126,14 +129,16 @@ class ImageServiceDependencies:
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
board_image_record_storage: BoardImageRecordStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.image_records = image_record_storage
self.image_files = image_file_storage
self.board_image_records = board_image_record_storage
self.metadata = metadata
self.urls = url
self.logger = logger
@ -144,25 +149,8 @@ class ImageServiceDependencies:
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
def __init__(self, services: ImageServiceDependencies):
self._services = services
def create(
self,
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
self._services.image_records.save(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
metadata=metadata,
)
self._services.files.save(
self._services.image_files.save(
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_name)
thumbnail_url = self._services.urls.get_image_url(image_name, True)
image_dto = self.get_dto(image_name)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_origin=image_origin,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
return image_dto
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, changes)
self._services.image_records.update(image_name, changes)
return self.get_dto(image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
def get_pil_image(self, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_name)
return self._services.image_files.get(image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
def get_record(self, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_name)
return self._services.image_records.get(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
def get_dto(self, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_name)
image_record = self._services.image_records.get(image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_name),
self._services.urls.get_image_url(image_name, True),
self._services.board_image_records.get_board_for_image(image_name),
)
return image_dto
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return self._services.files.get_path(image_name, thumbnail)
return self._services.image_files.get_path(image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
return self._services.image_files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
results = self._services.image_records.get_many(
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
image_dtos = list(
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(
r.image_name
),
),
results.items,
)
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
def delete(self, image_name: str):
try:
self._services.files.delete(image_name)
self._services.records.delete(image_name)
self._services.image_files.delete(image_name)
self._services.image_records.delete(image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise

View File

@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
@ -26,9 +28,9 @@ class InvocationServices:
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageService"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
images: "ImageServiceABC"
boards: "BoardServiceABC"
board_images: "BoardImagesServiceABC"
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
@ -39,7 +41,9 @@ class InvocationServices:
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageService",
images: "ImageServiceABC",
boards: "BoardServiceABC",
board_images: "BoardImagesServiceABC",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
@ -52,9 +56,12 @@ class InvocationServices:
self.logger = logger
self.latents = latents
self.images = images
self.boards = boards
self.board_images = board_images
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration
self.configuration = configuration
self.boards = boards

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import torch
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import (
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
) -> bool:
pass
@abstractmethod
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name and typeof the default model, or None
if none is defined.
"""
pass
@abstractmethod
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
pass
@abstractmethod
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
model_type,
)
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
return self.mgr.default_model()
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
"""Sets the default model to the indicated name."""
self.mgr.set_default_model(model_name, base_model, model_type)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> dict:
) -> list[dict]:
# ) -> dict:
"""
Return a dict of models in the format:
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
Return a list of models.
"""
return self.mgr.list_models(base_model, model_type)

View File

@ -0,0 +1,62 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field(
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field(
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
return BoardRecord(
board_id=board_id,
board_name=board_name,
cover_image_name=cover_image_name,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)

View File

@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs."""
"""Deserialized image record, enriched for the frontend."""
board_id: Union[str, None] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
)

View File

@ -266,6 +266,8 @@ class ModelManager(object):
for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary
@ -445,38 +447,6 @@ class ModelManager(object):
_cache = self.cache,
)
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
"""
Returns the name of the default model, or None
if none is defined.
"""
for model_key, model_config in self.models.items():
if model_config.default:
return self.parse_key(model_key)
for model_key, _ in self.models.items():
return self.parse_key(model_key)
else:
return None # TODO: or redo as (None, None, None)
def set_default_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> None:
"""
Set the default model. The change will not take
effect until you call model_manager.commit()
"""
model_key = self.model_key(model_name, base_model, model_type)
if model_key not in self.models:
raise Exception(f"Unknown model: {model_key}")
for cur_model_key, config in self.models.items():
config.default = cur_model_key == model_key
def model_info(
self,
model_name: str,
@ -503,9 +473,9 @@ class ModelManager(object):
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]:
) -> list[dict]:
"""
Return a dict of models, in format [base_model][model_type][model_name]
Return a list of models.
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
@ -513,7 +483,7 @@ class ModelManager(object):
object derived from models.yaml
"""
models = dict()
models = []
for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key]
@ -523,18 +493,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type:
continue
if cur_base_model not in models:
models[cur_base_model] = dict()
if cur_model_type not in models[cur_base_model]:
models[cur_base_model][cur_model_type] = dict()
models[cur_base_model][cur_model_type][cur_model_name] = dict(
model_dict = dict(
**model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase
name=cur_model_name,
base_model=cur_base_model,
type=cur_model_type,
)
models.append(model_dict)
return models
def print_models(self) -> None:
@ -646,7 +614,9 @@ class ModelManager(object):
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
# alias for config file
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
yaml_str = OmegaConf.to_yaml(data_to_save)
config_file_path = conf_file or self.config_path

View File

@ -1,3 +1,7 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
@ -29,10 +33,63 @@ MODEL_CLASSES = {
#},
}
def get_all_model_configs():
configs = set()
for models in MODEL_CLASSES.values():
for _, model in models.items():
configs.update(model._get_configs().values())
configs.discard(None)
return list(configs) # TODO: set, list or tuple
MODEL_CONFIGS = list()
OPENAPI_MODEL_CONFIGS = list()
class OpenAPIModelInfoBase(BaseModel):
name: str
base_model: BaseModelType
type: ModelType
for base_model, models in MODEL_CLASSES.items():
for model_type, model_class in models.items():
model_configs = set(model_class._get_configs().values())
model_configs.discard(None)
MODEL_CONFIGS.extend(model_configs)
for cfg in model_configs:
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
openapi_cfg_name = model_name + cfg_name
if openapi_cfg_name in vars():
continue
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
__annotations__ = dict(
type=Literal[model_type.value],
),
))
#globals()[openapi_cfg_name] = api_wrapper
vars()[openapi_cfg_name] = api_wrapper
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
def get_model_config_enums():
enums = list()
for model_config in MODEL_CONFIGS:
fields = inspect.get_annotations(model_config)
try:
field = fields["model_format"]
except:
raise Exception("format field not found")
# model_format: None
# model_format: SomeModelFormat
# model_format: Literal[SomeModelFormat.Diffusers]
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
enums.append(field)
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
enums.append(type(field.__args__[0]))
elif field is None:
pass
else:
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
return enums

View File

@ -48,12 +48,10 @@ class ModelError(str, Enum):
class ModelConfigBase(BaseModel):
path: str # or Path
#name: str # not included as present in model key
description: Optional[str] = Field(None)
format: Optional[str] = Field(None)
default: Optional[bool] = Field(False)
model_format: Optional[str] = Field(None)
# do not save to config
error: Optional[ModelError] = Field(None, exclude=True)
error: Optional[ModelError] = Field(None)
class Config:
use_enum_values = True
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):
return None
elif any(t is None for t in subtypes):
raise Exception(f"Unsupported definition: {subtypes}")
if subtypes[0] in ["diffusers", "transformers"]:
res_type = sys.modules[subtypes[0]]
subtypes = subtypes[1:]
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
continue
fields = inspect.get_annotations(value)
if "format" not in fields:
raise Exception("Invalid config definition - format field not found")
try:
field = fields["model_format"]
except:
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
format_type = typing.get_origin(fields["format"])
if format_type not in {None, Literal, Union}:
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
for model_format in field:
configs[model_format.value] = value
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
for model_format in field.__args__:
configs[model_format.value] = value
elif field is None:
configs[None] = value
if format_type == Union:
f_fields = fields["format"].__args__
else:
f_fields = (fields["format"],)
for field in f_fields:
if field is None:
format_name = None
else:
format_name = field.__args__[0]
configs[format_name] = value # TODO: error when override(multiple)?
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
cls.__configs = configs
return cls.__configs
@classmethod
def create_config(cls, **kwargs) -> ModelConfigBase:
if "format" not in kwargs:
raise Exception("Field 'format' not found in model config")
if "model_format" not in kwargs:
raise Exception("Field 'model_format' not found in model config")
configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs)
return configs[kwargs["model_format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
return cls.create_config(
path=path,
format=cls.detect_format(path),
model_format=cls.detect_format(path),
)
@classmethod

View File

@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@ -14,12 +15,16 @@ from .base import (
classproperty,
)
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: ControlNetModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return ControlNetModelFormat.Diffusers
else:
return "checkpoint"
return ControlNetModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path

View File

@ -1,5 +1,6 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from .base import (
ModelBase,
@ -12,11 +13,15 @@ from .base import (
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
Diffusers = "diffusers"
class LoRAModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["lycoris"], Literal["diffusers"]]
model_format: LoRAModelFormat # TODO:
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Lora
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return LoRAModelFormat.Diffusers
else:
return "lycoris"
return LoRAModelFormat.LyCORIS
@classmethod
def convert_if_required(
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == "diffusers":
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
# TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported")
else:

View File

@ -1,5 +1,6 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
@ -19,16 +20,19 @@ from .base import (
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion1ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion1ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion1ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion1ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
else:
return model_path
class StableDiffusion2ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class StableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"]
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"]
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: Optional[str] = Field(None)
variant: ModelVariantType
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if model_format == StableDiffusion2ModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
elif model_format == StableDiffusion2ModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
return cls.create_config(
path=path,
format=model_format,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
return StableDiffusion2ModelFormat.Diffusers
else:
return "checkpoint"
return StableDiffusion2ModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention
prediction_type = config.prediction_type
upcast_attention = model_config.upcast_attention
prediction_type = model_config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")

View File

@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
#model_size: int
class Config(ModelConfigBase):
format: None
model_format: None
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.TextualInversion

View File

@ -1,5 +1,7 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from .base import (
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
class VaeModel(ModelBase):
#vae_class: Type
#model_size: int
class Config(ModelConfigBase):
format: Union[Literal["checkpoint"], Literal["diffusers"]]
model_format: VaeModelFormat
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.Vae
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if os.path.isdir(path):
return "diffusers"
return VaeModelFormat.Diffusers
else:
return "checkpoint"
return VaeModelFormat.Checkpoint
@classmethod
def convert_if_required(
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers":
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
return _convert_vae_ckpt_and_cache(
weights_path=model_path,
output_path=output_path,

View File

@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { useListModelsQuery } from 'services/apiSlice';
const DEFAULT_CONFIG = {};
@ -45,6 +47,18 @@ const App = ({
const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet',
});
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
const { data: embeddingModels } = useListModelsQuery({
model_type: 'embedding',
});
const [loadingOverridden, setLoadingOverridden] = useState(false);
const dispatch = useAppDispatch();
@ -143,6 +157,7 @@ const App = ({
</Portal>
</Grid>
<DeleteImageModal />
<UpdateImageBoardModal />
<Toaster />
<GlobalHotkeys />
</>

View File

@ -21,6 +21,8 @@ import {
DeleteImageContext,
DeleteImageContextProvider,
} from 'app/contexts/DeleteImageContext';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
@ -76,11 +78,13 @@ const InvokeAIUI = ({
<ThemeLocaleProvider>
<ImageDndContext>
<DeleteImageContextProvider>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
<AddImageToBoardContextProvider>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</AddImageToBoardContextProvider>
</DeleteImageContextProvider>
</ImageDndContext>
</ThemeLocaleProvider>

View File

@ -0,0 +1,89 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api';
import { useAddImageToBoardMutation } from 'services/apiSlice';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
type AddImageToBoardContextValue = {
/**
* Whether the move image dialog is open.
*/
isOpen: boolean;
/**
* Closes the move image dialog.
*/
onClose: () => void;
/**
* The image pending movement
*/
image?: ImageDTO;
onClickAddToBoard: (image: ImageDTO) => void;
handleAddToBoard: (boardId: string) => void;
};
export const AddImageToBoardContext =
createContext<AddImageToBoardContextValue>({
isOpen: false,
onClose: () => undefined,
onClickAddToBoard: () => undefined,
handleAddToBoard: () => undefined,
});
type Props = PropsWithChildren;
export const AddImageToBoardContextProvider = (props: Props) => {
const [imageToMove, setImageToMove] = useState<ImageDTO>();
const { isOpen, onOpen, onClose } = useDisclosure();
const [addImageToBoard, result] = useAddImageToBoardMutation();
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToMove(undefined);
onClose();
}, [onClose]);
const onClickAddToBoard = useCallback(
(image?: ImageDTO) => {
if (!image) {
return;
}
setImageToMove(image);
onOpen();
},
[setImageToMove, onOpen]
);
const handleAddToBoard = useCallback(
(boardId: string) => {
if (imageToMove) {
addImageToBoard({
board_id: boardId,
image_name: imageToMove.image_name,
});
closeAndClearImageToDelete();
}
},
[addImageToBoard, closeAndClearImageToDelete, imageToMove]
);
return (
<AddImageToBoardContext.Provider
value={{
isOpen,
image: imageToMove,
onClose: closeAndClearImageToDelete,
onClickAddToBoard,
handleAddToBoard,
}}
>
{props.children}
</AddImageToBoardContext.Provider>
);
};

View File

@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
(state: RootState, image_name?: string) => image_name,
],
(generation, canvas, nodes, controlNet, image_name) => {
const isInitialImage = generation.initialImage?.image_name === image_name;
const isInitialImage = generation.initialImage?.imageName === image_name;
const isCanvasImage = canvas.layerState.objects.some(
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
(obj) => obj.kind === 'image' && obj.imageName === image_name
);
const isNodesImage = nodes.nodes.some((node) => {
return some(
node.data.inputs,
(input) =>
input.type === 'image' && input.value?.image_name === image_name
(input) => input.type === 'image' && input.value === image_name
);
});
const isControlNetImage = some(
controlNet.controlNets,
(c) =>
c.controlImage?.image_name === image_name ||
c.processedControlImage?.image_name === image_name
c.controlImage === image_name || c.processedControlImage === image_name
);
const imageUsage: ImageUsage = {

View File

@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
@ -18,7 +17,6 @@ const serializationDenylist: {
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,

View File

@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
@ -21,7 +20,6 @@ const initialStates: {
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
system: initialSystemState,

View File

@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
import {
addImageAddedToBoardFulfilledListener,
addImageAddedToBoardRejectedListener,
} from './listeners/imageAddedToBoard';
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
import {
addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard';
export const listenerMiddleware = createListenerMiddleware();
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
AppDispatch
>;
/**
* The RTK listener middleware is a lightweight alternative sagas/observables.
*
* Most side effect logic should live in a listener.
*/
// Image uploaded
addImageUploadedFulfilledListener();
addImageUploadedRejectedListener();
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
// Update image URLs on connect
addUpdateImageUrlsOnConnectListener();
// Boards
addImageAddedToBoardFulfilledListener();
addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener();

View File

@ -0,0 +1,99 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { boardIdSelected } from 'features/gallery/store/boardSlice';
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
import { api } from 'services/apiSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addBoardIdSelectedListener = () => {
startAppListening({
actionCreator: boardIdSelected,
effect: (action, { getState, dispatch }) => {
const boardId = action.payload;
// we need to check if we need to fetch more images
const state = getState();
const allImages = selectImagesAll(state);
if (!boardId) {
// a board was unselected
dispatch(imageSelected(allImages[0]?.image_name));
return;
}
const { categories } = state.images;
const filteredImages = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
return isInCategory && isInSelectedBoard;
});
// get the board from the cache
const { data: boards } = api.endpoints.listAllBoards.select()(state);
const board = boards?.find((b) => b.board_id === boardId);
if (!board) {
// can't find the board in cache...
dispatch(imageSelected(allImages[0]?.image_name));
return;
}
dispatch(imageSelected(board.cover_image_name));
// if we haven't loaded one full page of images from this board, load more
if (
filteredImages.length < board.image_count &&
filteredImages.length < IMAGES_PER_PAGE
) {
dispatch(receivedPageOfImages({ categories, boardId }));
}
},
});
};
export const addBoardIdSelected_changeSelectedImage_listener = () => {
startAppListening({
actionCreator: boardIdSelected,
effect: (action, { getState, dispatch }) => {
const boardId = action.payload;
const state = getState();
// we need to check if we need to fetch more images
if (!boardId) {
// a board was unselected - we don't need to do anything
return;
}
const { categories } = state.images;
const filteredImages = selectImagesAll(state).filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
return isInCategory && isInSelectedBoard;
});
// get the board from the cache
const { data: boards } = api.endpoints.listAllBoards.select()(state);
const board = boards?.find((b) => b.board_id === boardId);
if (!board) {
// can't find the board in cache...
return;
}
// if we haven't loaded one full page of images from this board, load more
if (
filteredImages.length < board.image_count &&
filteredImages.length < IMAGES_PER_PAGE
) {
dispatch(receivedPageOfImages({ categories, boardId }));
}
},
});
};

View File

@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: pick(controlNet.controlImage, ['image_name']),
image: { image_name: controlNet.controlImage },
},
},
};
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
dispatch(
controlNetProcessedImageChanged({
controlNetId,
processedControlImage,
processedControlImage: processedControlImage.image_name,
})
);
}

View File

@ -0,0 +1,40 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addImageAddedToBoardFulfilledListener = () => {
startAppListening({
matcher: api.endpoints.addImageToBoard.matchFulfilled,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Image added to board'
);
dispatch(
imageMetadataReceived({
imageName: image_name,
})
);
},
});
};
export const addImageAddedToBoardRejectedListener = () => {
startAppListening({
matcher: api.endpoints.addImageToBoard.matchRejected,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Problem adding image to board'
);
},
});
};

View File

@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
startAppListening({
actionCreator: imageCategoriesChanged,
effect: (action, { getState, dispatch }) => {
const filteredImagesCount = selectFilteredImagesAsArray(
getState()
).length;
const state = getState();
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
if (!filteredImagesCount) {
dispatch(receivedPageOfImages());
dispatch(
receivedPageOfImages({
categories: action.payload,
boardId: state.boards.selectedBoardId,
})
);
}
},
});

View File

@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
imageRemoved,
selectImagesEntities,
selectImagesIds,
} from 'features/gallery/store/imagesSlice';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
const moduleLog = log.child({ namespace: 'image' });
/**
* Called when the user requests an image deletion
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => {
effect: async (action, { dispatch, getState, condition }) => {
const { image, imageUsage } = action.payload;
const { image_name } = image;
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
const state = getState();
const selectedImage = state.gallery.selectedImage;
if (selectedImage && selectedImage.image_name === image_name) {
if (selectedImage === image_name) {
const ids = selectImagesIds(state);
const entities = selectImagesEntities(state);
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = entities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
dispatch(imageSelected(newSelectedImageId as string));
} else {
dispatch(imageSelected());
}
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
dispatch(imageRemoved(image_name));
// Delete from server
dispatch(imageDeleted({ imageName: image_name }));
const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
// Wait for successful deletion, then trigger boards to re-fetch
const wasImageDeleted = await condition(
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
imageDeleted.fulfilled.match(action) &&
action.meta.requestId === requestId,
30000
);
if (wasImageDeleted) {
dispatch(
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
);
}
},
});
};

View File

@ -0,0 +1,40 @@
import { log } from 'app/logging/useLogger';
import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'boards' });
export const addImageRemovedFromBoardFulfilledListener = () => {
startAppListening({
matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Image added to board'
);
dispatch(
imageMetadataReceived({
imageName: image_name,
})
);
},
});
};
export const addImageRemovedFromBoardRejectedListener = () => {
startAppListening({
matcher: api.endpoints.removeImageFromBoard.matchRejected,
effect: (action, { getState, dispatch }) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
moduleLog.debug(
{ data: { board_id, image_name } },
'Problem adding image to board'
);
},
});
};

View File

@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
dispatch(
controlNetImageChanged({
controlNetId,
controlImage: image.image_name,
})
);
return;
}

View File

@ -1,9 +1,8 @@
import { startAppListening } from '../..';
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { receivedPageOfImages } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
moduleLog.debug({ timestamp }, 'Connected');
const { models, nodes, config, images } = getState();
const { nodes, config, images } = getState();
const { disabledTabs } = config;
if (!images.ids.length) {
dispatch(receivedPageOfImages());
}
if (!models.ids.length) {
dispatch(receivedModels());
dispatch(
receivedPageOfImages({
categories: ['general'],
isIntermediate: false,
})
);
}
if (!nodes.schema && !disabledTabs.includes('nodes')) {

View File

@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
import { sessionCanceled } from 'services/thunks/session';
import { isImageOutput } from 'services/types/guards';
import { progressImageSet } from 'features/system/store/systemSlice';
import { api } from 'services/apiSlice';
const moduleLog = log.child({ namespace: 'socketio' });
const nodeDenylist = ['dataURL_image'];
@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
const sessionId = action.payload.data.graph_execution_state_id;
const { cancelType, isCancelScheduled } = getState().system;
const { cancelType, isCancelScheduled, boardIdToAddTo } =
getState().system;
// Handle scheduled cancelation
if (cancelType === 'scheduled' && isCancelScheduled) {
@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
dispatch(addImageToStagingArea(imageDTO));
}
if (boardIdToAddTo && !imageDTO.is_intermediate) {
dispatch(
api.endpoints.addImageToBoard.initiate({
board_id: boardIdToAddTo,
image_name,
})
);
}
dispatch(progressImageSet(null));
}
// pass along the socket event as an application action

View File

@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
selectImagesEntities,
],
(generation, canvas, nodes, controlNet, imageEntities) => {
const allUsedImages: ImageDTO[] = [];
const allUsedImages: string[] = [];
if (generation.initialImage) {
allUsedImages.push(generation.initialImage);
allUsedImages.push(generation.initialImage.imageName);
}
canvas.layerState.objects.forEach((obj) => {
if (obj.kind === 'image') {
allUsedImages.push(obj.image);
allUsedImages.push(obj.imageName);
}
});
@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
forEach(imageEntities, (image) => {
if (image) {
allUsedImages.push(image);
allUsedImages.push(image.image_name);
}
});
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
`Fetching new image URLs for ${allUsedImages.length} images`
);
allUsedImages.forEach(({ image_name }) => {
allUsedImages.forEach((image_name) => {
dispatch(
imageUrlsReceived({
imageName: image_name,

View File

@ -5,40 +5,39 @@ import {
configureStore,
} from '@reduxjs/toolkit';
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import imagesReducer from 'features/gallery/store/imagesSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
// import sessionReducer from 'features/system/store/sessionSlice';
import configReducer from 'features/system/store/configSlice';
import uiReducer from 'features/ui/store/uiSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { LOCALSTORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
import { api } from 'services/apiSlice';
const allReducers = {
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
system: systemReducer,
@ -47,7 +46,9 @@ const allReducers = {
hotkeys: hotkeysReducer,
images: imagesReducer,
controlNet: controlNetReducer,
boards: boardsReducer,
// session: sessionReducer,
[api.reducerPath]: api.reducer,
};
const rootReducer = combineReducers(allReducers);
@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'gallery',
'generation',
'lightbox',
// 'models',
'nodes',
'postprocessing',
'system',
'ui',
'controlNet',
// 'boards',
// 'hotkeys',
// 'config',
];
@ -84,6 +85,7 @@ export const store = configureStore({
immutableCheck: false,
serializableCheck: false,
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {

View File

@ -9,7 +9,7 @@ import {
import { useDraggable, useDroppable } from '@dnd-kit/core';
import { useCombinedRefs } from '@dnd-kit/utilities';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent, useCallback } from 'react';
@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
isDropDisabled = false,
isDragDisabled = false,
isUploadDisabled = false,
fallback = <IAIImageFallback />,
fallback = <IAIImageLoadingFallback />,
payloadImage,
minSize = 24,
postUploadAction,

View File

@ -1,10 +1,20 @@
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
import {
As,
Flex,
FlexProps,
Icon,
IconProps,
Spinner,
SpinnerProps,
} from '@chakra-ui/react';
import { ReactElement } from 'react';
import { FaImage } from 'react-icons/fa';
type Props = FlexProps & {
spinnerProps?: SpinnerProps;
};
export const IAIImageFallback = (props: Props) => {
export const IAIImageLoadingFallback = (props: Props) => {
const { spinnerProps, ...rest } = props;
const { sx, ...restFlexProps } = rest;
return (
@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
</Flex>
);
};
type IAINoImageFallbackProps = {
flexProps?: FlexProps;
iconProps?: IconProps;
as?: As;
};
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
return (
<Flex
sx={{
bg: 'base.900',
opacity: 0.7,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
...flexSx,
}}
{...restFlexProps}
>
<Icon
as={props.as ?? FaImage}
sx={{ color: 'base.700', ...iconSx }}
{...restIconProps}
/>
</Flex>
);
};

View File

@ -1,14 +1,21 @@
import { Image } from 'react-konva';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Image, Rect } from 'react-konva';
import { useGetImageDTOQuery } from 'services/apiSlice';
import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageProps = {
url: string;
x: number;
y: number;
canvasImage: CanvasImage;
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
const { url, x, y } = props;
const [image] = useImage(url, 'anonymous');
const { width, height, x, y, imageName } = props.canvasImage;
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
}
return <Image x={x} y={y} image={image} listening={false} />;
};

View File

@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
<Group name="outpainting-objects" listening={false}>
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={obj.image.image_url}
/>
);
return <IAICanvasImage key={i} canvasImage={obj} />;
} else if (isCanvasBaseLine(obj)) {
const line = (
<Line

View File

@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage
url={currentStagingAreaImage.image.image_url}
x={x}
y={y}
/>
<IAICanvasImage canvasImage={currentStagingAreaImage} />
)}
{shouldShowStagingOutline && (
<Group>

View File

@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
y: 0,
width: width,
height: height,
image: image,
imageName: image.image_name,
},
],
};
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
kind: 'image',
layer: 'base',
...state.layerState.stagingArea.boundingBox,
image,
imageName: image.image_name,
});
state.layerState.stagingArea.selectedImageIndex =
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
state.layerState.objects.forEach((object) => {
if (object.kind === 'image') {
if (object.image.image_name === image_name) {
object.image.image_url = image_url;
object.image.thumbnail_url = thumbnail_url;
}
}
});
// state.layerState.objects.forEach((object) => {
// if (object.kind === 'image') {
// if (object.image.image_name === image_name) {
// object.image.image_url = image_url;
// object.image.thumbnail_url = thumbnail_url;
// }
// }
// });
state.layerState.stagingArea.images.forEach((stagedImage) => {
if (stagedImage.image.image_name === image_name) {
stagedImage.image.image_url = image_url;
stagedImage.image.thumbnail_url = thumbnail_url;
}
});
});
// state.layerState.stagingArea.images.forEach((stagedImage) => {
// if (stagedImage.image.image_name === image_name) {
// stagedImage.image.image_url = image_url;
// stagedImage.image.thumbnail_url = thumbnail_url;
// }
// });
// });
},
});

View File

@ -38,7 +38,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: ImageDTO;
imageName: string;
};
export type CanvasMaskLine = {

View File

@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { AnimatePresence, motion } from 'framer-motion';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
controlNetSelector,
@ -31,24 +33,45 @@ type Props = {
const ControlNetImagePreview = (props: Props) => {
const { imageSx } = props;
const { controlNetId, controlImage, processedControlImage, processorType } =
props.controlNet;
const {
controlNetId,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
} = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
const {
data: controlImage,
isLoading: isLoadingControlImage,
isError: isErrorControlImage,
isSuccess: isSuccessControlImage,
} = useGetImageDTOQuery(controlImageName ?? skipToken);
const {
data: processedControlImage,
isLoading: isLoadingProcessedControlImage,
isError: isErrorProcessedControlImage,
isSuccess: isSuccessProcessedControlImage,
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (controlImage?.image_name === droppedImage.image_name) {
if (controlImageName === droppedImage.image_name) {
return;
}
setIsMouseOverImage(false);
dispatch(
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
controlNetImageChanged({
controlNetId,
controlImage: droppedImage.image_name,
})
);
},
[controlImage, controlNetId, dispatch]
[controlImageName, controlNetId, dispatch]
);
const handleResetControlImage = useCallback(() => {
@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
h: 'full',
}}
>
<IAIImageFallback />
<IAIImageLoadingFallback />
</Box>
)}
{controlImage && (

View File

@ -39,8 +39,8 @@ export type ControlNetConfig = {
weight: number;
beginStepPct: number;
endStepPct: number;
controlImage: ImageDTO | null;
processedControlImage: ImageDTO | null;
controlImage: string | null;
processedControlImage: string | null;
processorType: ControlNetProcessorType;
processorNode: RequiredControlNetProcessorNode;
shouldAutoConfig: boolean;
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
},
controlNetAddedFromImage: (
state,
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
action: PayloadAction<{ controlNetId: string; controlImage: string }>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId] = {
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
controlImage: ImageDTO | null;
controlImage: string | null;
}>
) => {
const { controlNetId, controlImage } = action.payload;
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
state,
action: PayloadAction<{
controlNetId: string;
processedControlImage: ImageDTO | null;
processedControlImage: string | null;
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
// Preemptively remove the image from the gallery
const { imageName } = action.meta.arg;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === imageName) {
if (c.controlImage === imageName) {
c.controlImage = null;
c.processedControlImage = null;
}
if (c.processedControlImage?.image_name === imageName) {
if (c.processedControlImage === imageName) {
c.processedControlImage = null;
}
});
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
forEach(state.controlNets, (c) => {
if (c.controlImage?.image_name === image_name) {
c.controlImage.image_url = image_url;
c.controlImage.thumbnail_url = thumbnail_url;
}
if (c.processedControlImage?.image_name === image_name) {
c.processedControlImage.image_url = image_url;
c.processedControlImage.thumbnail_url = thumbnail_url;
}
});
});
// forEach(state.controlNets, (c) => {
// if (c.controlImage?.image_name === image_name) {
// c.controlImage.image_url = image_url;
// c.controlImage.thumbnail_url = thumbnail_url;
// }
// if (c.processedControlImage?.image_name === image_name) {
// c.processedControlImage.image_url = image_url;
// c.processedControlImage.thumbnail_url = thumbnail_url;
// }
// });
// });
builder.addCase(appSocketInvocationError, (state, action) => {
state.pendingControlImages = [];

View File

@ -0,0 +1,27 @@
import IAIButton from 'common/components/IAIButton';
import { useCallback } from 'react';
import { useCreateBoardMutation } from 'services/apiSlice';
const DEFAULT_BOARD_NAME = 'My Board';
const AddBoardButton = () => {
const [createBoard, { isLoading }] = useCreateBoardMutation();
const handleCreateBoard = useCallback(() => {
createBoard(DEFAULT_BOARD_NAME);
}, [createBoard]);
return (
<IAIButton
isLoading={isLoading}
aria-label="Add Board"
onClick={handleCreateBoard}
size="sm"
sx={{ px: 4 }}
>
Add Board
</IAIButton>
);
};
export default AddBoardButton;

View File

@ -0,0 +1,93 @@
import { Flex, Text } from '@chakra-ui/react';
import { FaImages } from 'react-icons/fa';
import { boardIdSelected } from '../../store/boardSlice';
import { useDispatch } from 'react-redux';
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
import { AnimatePresence } from 'framer-motion';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
import { useCallback } from 'react';
import { ImageDTO } from 'services/api';
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
import { useDroppable } from '@dnd-kit/core';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const handleAllImagesBoardClick = () => {
dispatch(boardIdSelected());
};
const [removeImageFromBoard, { isLoading }] =
useRemoveImageFromBoardMutation();
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (!droppedImage.board_id) {
return;
}
removeImageFromBoard({
board_id: droppedImage.board_id,
image_name: droppedImage.image_name,
});
},
[removeImageFromBoard]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_all_images`,
data: {
handleDrop,
},
});
return (
<Flex
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
borderRadius: 'base',
}}
onClick={handleAllImagesBoardClick}
>
<Flex
ref={setNodeRef}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
}}
>
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
<AnimatePresence>
{isSelected && <SelectedItemOverlay />}
</AnimatePresence>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex>
<Text
sx={{
color: isSelected ? 'base.50' : 'base.200',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
}}
>
All Images
</Text>
</Flex>
);
};
export default AllImagesBoard;

View File

@ -0,0 +1,134 @@
import {
Collapse,
Flex,
Grid,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import {
boardsSelector,
setBoardSearchText,
} from 'features/gallery/store/boardSlice';
import { memo, useState } from 'react';
import HoverableBoard from './HoverableBoard';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import AddBoardButton from './AddBoardButton';
import AllImagesBoard from './AllImagesBoard';
import { CloseIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/apiSlice';
const selector = createSelector(
[boardsSelector],
(boardsState) => {
const { selectedBoardId, searchText } = boardsState;
return { selectedBoardId, searchText };
},
defaultSelectorOptions
);
type Props = {
isOpen: boolean;
};
const BoardsList = (props: Props) => {
const { isOpen } = props;
const dispatch = useAppDispatch();
const { selectedBoardId, searchText } = useAppSelector(selector);
const { data: boards } = useListAllBoardsQuery();
const filteredBoards = searchText
? boards?.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase())
)
: boards;
const [searchMode, setSearchMode] = useState(false);
const handleBoardSearch = (searchTerm: string) => {
setSearchMode(searchTerm.length > 0);
dispatch(setBoardSearchText(searchTerm));
};
const clearBoardSearch = () => {
setSearchMode(false);
dispatch(setBoardSearchText(''));
};
return (
<Collapse in={isOpen} animateOpacity>
<Flex
sx={{
flexDir: 'column',
gap: 2,
bg: 'base.800',
borderRadius: 'base',
p: 2,
mt: 2,
}}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<InputGroup>
<Input
placeholder="Search Boards..."
value={searchText}
onChange={(e) => {
handleBoardSearch(e.target.value);
}}
/>
{searchText && searchText.length && (
<InputRightElement>
<IconButton
onClick={clearBoardSearch}
size="xs"
variant="ghost"
aria-label="Clear Search"
icon={<CloseIcon boxSize={3} />}
/>
</InputRightElement>
)}
</InputGroup>
<AddBoardButton />
</Flex>
<OverlayScrollbarsComponent
defer
style={{ height: '100%', width: '100%' }}
options={{
scrollbars: {
visibility: 'auto',
autoHide: 'move',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
}}
>
<Grid
className="list-container"
sx={{
gap: 2,
gridTemplateRows: '5.5rem 5.5rem',
gridAutoFlow: 'column dense',
gridAutoColumns: '4rem',
}}
>
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
{filteredBoards &&
filteredBoards.map((board) => (
<HoverableBoard
key={board.board_id}
board={board}
isSelected={selectedBoardId === board.board_id}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</Flex>
</Collapse>
);
};
export default memo(BoardsList);

View File

@ -0,0 +1,193 @@
import {
Badge,
Box,
Editable,
EditableInput,
EditablePreview,
Flex,
Image,
MenuItem,
MenuList,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaFolder, FaTrash } from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu';
import { BoardDTO, ImageDTO } from 'services/api';
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
import { boardIdSelected } from 'features/gallery/store/boardSlice';
import {
useAddImageToBoardMutation,
useDeleteBoardMutation,
useGetImageDTOQuery,
useUpdateBoardMutation,
} from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useDroppable } from '@dnd-kit/core';
import { AnimatePresence } from 'framer-motion';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { SelectedItemOverlay } from '../SelectedItemOverlay';
interface HoverableBoardProps {
board: BoardDTO;
isSelected: boolean;
}
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
const dispatch = useAppDispatch();
const { data: coverImage } = useGetImageDTOQuery(
board.cover_image_name ?? skipToken
);
const { board_name, board_id } = board;
const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation();
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
useDeleteBoardMutation();
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
useAddImageToBoardMutation();
const handleUpdateBoardName = (newBoardName: string) => {
updateBoard({ board_id, changes: { board_name: newBoardName } });
};
const handleDeleteBoard = useCallback(() => {
deleteBoard(board_id);
}, [board_id, deleteBoard]);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (droppedImage.board_id === board_id) {
return;
}
addImageToBoard({ board_id, image_name: droppedImage.image_name });
},
[addImageToBoard, board_id]
);
const {
isOver,
setNodeRef,
active: isDropActive,
} = useDroppable({
id: `board_droppable_${board_id}`,
data: {
handleDrop,
},
});
return (
<Box sx={{ touchAction: 'none' }}>
<ContextMenu<HTMLDivElement>
menuProps={{ size: 'sm', isLazy: true }}
renderMenu={() => (
<MenuList sx={{ visibility: 'visible !important' }}>
<MenuItem
sx={{ color: 'error.300' }}
icon={<FaTrash />}
onClickCapture={handleDeleteBoard}
>
Delete Board
</MenuItem>
</MenuList>
)}
>
{(ref) => (
<Flex
key={board_id}
userSelect="none"
ref={ref}
sx={{
flexDir: 'column',
justifyContent: 'space-between',
alignItems: 'center',
cursor: 'pointer',
w: 'full',
h: 'full',
}}
>
<Flex
ref={setNodeRef}
onClick={handleSelectBoard}
sx={{
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
w: 'full',
aspectRatio: '1/1',
overflow: 'hidden',
}}
>
{board.cover_image_name && coverImage?.image_url && (
<Image src={coverImage?.image_url} draggable={false} />
)}
{!(board.cover_image_name && coverImage?.image_url) && (
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
)}
<Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
<Badge variant="solid">{board.image_count}</Badge>
</Flex>
<AnimatePresence>
{isSelected && <SelectedItemOverlay />}
</AnimatePresence>
<AnimatePresence>
{isDropActive && <IAIDropOverlay isOver={isOver} />}
</AnimatePresence>
</Flex>
<Box sx={{ width: 'full' }}>
<Editable
defaultValue={board_name}
submitOnBlur={false}
onSubmit={(nextValue) => {
handleUpdateBoardName(nextValue);
}}
>
<EditablePreview
sx={{
color: isSelected ? 'base.50' : 'base.200',
fontWeight: isSelected ? 600 : undefined,
fontSize: 'xs',
textAlign: 'center',
p: 0,
}}
noOfLines={1}
/>
<EditableInput
sx={{
color: 'base.50',
fontSize: 'xs',
borderColor: 'base.500',
p: 0,
outline: 0,
}}
/>
</Editable>
</Box>
</Flex>
)}
</ContextMenu>
</Box>
);
});
HoverableBoard.displayName = 'HoverableBoard';
export default HoverableBoard;

View File

@ -0,0 +1,93 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Box,
Flex,
Spinner,
Text,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/apiSlice';
const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll);
const { data: boards, isFetching } = useListAllBoardsQuery();
const { isOpen, onClose, handleAddToBoard, image } = useContext(
AddImageToBoardContext
);
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const cancelRef = useRef<HTMLButtonElement>(null);
const currentBoard = boards?.find(
(board) => board.board_id === image?.board_id
);
return (
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
</AlertDialogHeader>
<AlertDialogBody>
<Box>
<Flex direction="column" gap={3}>
{currentBoard && (
<Text>
Moving this image from{' '}
<strong>{currentBoard.board_name}</strong> to
</Text>
)}
{isFetching ? (
<Spinner />
) : (
<IAIMantineSelect
placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={(boards ?? []).map((board) => ({
label: board.board_name,
value: board.board_id,
}))}
/>
)}
</Flex>
</Box>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton onClick={onClose}>Cancel</IAIButton>
<IAIButton
isDisabled={!selectedBoard}
colorScheme="accent"
onClick={() => {
if (selectedBoard) {
handleAddToBoard(selectedBoard);
}
}}
ml={3}
>
{currentBoard ? 'Move' : 'Add'}
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(UpdateImageBoardModal);

View File

@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { DeleteImageButton } from './DeleteImageModal';
import { selectImagesById } from '../store/imagesSlice';
import { RootState } from 'app/store/store';
const currentImageButtonsSelector = createSelector(
[
(state: RootState) => state,
systemSelector,
gallerySelector,
postprocessingSelector,
@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
lightboxSelector,
activeTabNameSelector,
],
(system, gallery, postprocessing, ui, lightbox, activeTabName) => {
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
const {
isProcessing,
isConnected,
@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
shouldShowProgressInViewer,
} = ui;
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
const { selectedImage } = gallery;
return {
@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
activeTabName,
isLightboxOpen,
shouldHidePreview,
image: selectedImage,
seed: selectedImage?.metadata?.seed,
prompt: selectedImage?.metadata?.positive_conditioning,
negativePrompt: selectedImage?.metadata?.negative_conditioning,
image: imageDTO,
seed: imageDTO?.metadata?.seed,
prompt: imageDTO?.metadata?.positive_conditioning,
negativePrompt: imageDTO?.metadata?.negative_conditioning,
shouldShowProgressInViewer,
};
},

View File

@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import { memo, useCallback } from 'react';
import { systemSelector } from 'features/system/store/systemSelectors';
import { configSelector } from '../../system/store/configSelectors';
import { useAppToaster } from 'app/components/Toaster';
import { imageSelected } from '../store/gallerySlice';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
export const imagesSelector = createSelector(
[uiSelector, gallerySelector, systemSelector],
@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
return {
shouldShowImageDetails,
shouldHidePreview,
image: selectedImage,
selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
const CurrentImagePreview = () => {
const {
shouldShowImageDetails,
image,
selectedImage,
progressImage,
shouldShowProgressInViewer,
shouldAntialiasProgressImage,
} = useAppSelector(imagesSelector);
// const image = useAppSelector((state: RootState) =>
// selectImagesById(state, selectedImage ?? '')
// );
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(selectedImage ?? skipToken);
const dispatch = useAppDispatch();
const handleDrop = useCallback(
@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
if (droppedImage.image_name === image?.image_name) {
return;
}
dispatch(imageSelected(droppedImage));
dispatch(imageSelected(droppedImage.image_name));
},
[dispatch, image?.image_name]
);
@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
}}
>
<IAIDndImage
image={image}
image={selectedImage && image ? image : undefined}
onDrop={handleDrop}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
isUploadDisabled={true}
/>
</Flex>
)}
{shouldShowImageDetails && image && (
{shouldShowImageDetails && image && selectedImage && (
<Box
sx={{
position: 'absolute',
@ -119,7 +131,7 @@ const CurrentImagePreview = () => {
<ImageMetadataViewer image={image} />
</Box>
)}
{!shouldShowImageDetails && image && (
{!shouldShowImageDetails && image && selectedImage && (
<Box
sx={{
position: 'absolute',

View File

@ -2,7 +2,14 @@ import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useContext, useState } from 'react';
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import {
FaCheck,
FaExpand,
FaFolder,
FaImage,
FaShare,
FaTrash,
} from 'react-icons/fa';
import { ContextMenu } from 'chakra-ui-contextmenu';
import {
resizeAndScaleCanvas,
@ -27,6 +34,8 @@ import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
import { useDraggable } from '@dnd-kit/core';
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
@ -62,17 +71,10 @@ interface HoverableImageProps {
isSelected: boolean;
}
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) =>
prev.image.image_name === next.image.image_name &&
prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = memo((props: HoverableImageProps) => {
const HoverableImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const {
activeTabName,
@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { onDelete } = useContext(DeleteImageContext);
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
const handleDelete = useCallback(() => {
onDelete(image);
}, [image, onDelete]);
@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
},
});
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleSelectImage = useCallback(() => {
dispatch(imageSelected(image));
dispatch(imageSelected(image.image_name));
}, [image, dispatch]);
// Recall parameters handlers
@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// dispatch(setIsLightboxOpen(true));
};
const handleAddToBoard = useCallback(() => {
onClickAddToBoard(image);
}, [image, onClickAddToBoard]);
const handleRemoveFromBoard = useCallback(() => {
if (!image.board_id) {
return;
}
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
}, [image.board_id, image.image_name, removeFromBoard]);
const handleOpenInNewTab = () => {
window.open(image.image_url, '_blank');
};
@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
{image.board_id ? 'Change Board' : 'Add to Board'}
</MenuItem>
{image.board_id && (
<MenuItem
icon={<FaFolder />}
onClickCapture={handleRemoveFromBoard}
>
Remove from Board
</MenuItem>
)}
<MenuItem
sx={{ color: 'error.300' }}
icon={<FaTrash />}
@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</ContextMenu>
</Box>
);
}, memoEqualityCheck);
};
HoverableImage.displayName = 'HoverableImage';
export default HoverableImage;
export default memo(HoverableImage);

View File

@ -1,12 +1,15 @@
import {
Box,
Button,
ButtonGroup,
Flex,
FlexProps,
Grid,
Icon,
Text,
VStack,
forwardRef,
useDisclosure,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -20,6 +23,7 @@ import {
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
setGalleryView,
} from 'features/gallery/store/gallerySlice';
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
@ -53,41 +57,51 @@ import {
selectImagesAll,
} from '../store/imagesSlice';
import { receivedPageOfImages } from 'services/thunks/image';
import BoardsList from './Boards/BoardsList';
import { boardsSelector } from '../store/boardSlice';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { useListAllBoardsQuery } from 'services/apiSlice';
const categorySelector = createSelector(
const itemSelector = createSelector(
[(state: RootState) => state],
(state) => {
const { images } = state;
const { categories } = images;
const { categories, total: allImagesTotal, isLoading } = state.images;
const { selectedBoardId } = state.boards;
const allImages = selectImagesAll(state);
const filteredImages = allImages.filter((i) =>
categories.includes(i.image_category)
);
const images = allImages.filter((i) => {
const isInCategory = categories.includes(i.image_category);
const isInSelectedBoard = selectedBoardId
? i.board_id === selectedBoardId
: true;
return isInCategory && isInSelectedBoard;
});
return {
images: filteredImages,
isLoading: images.isLoading,
areMoreImagesAvailable: filteredImages.length < images.total,
categories: images.categories,
images,
allImagesTotal,
isLoading,
categories,
selectedBoardId,
};
},
defaultSelectorOptions
);
const mainSelector = createSelector(
[gallerySelector, uiSelector],
(gallery, ui) => {
[gallerySelector, uiSelector, boardsSelector],
(gallery, ui, boards) => {
const {
galleryImageMinimumWidth,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView,
} = gallery;
const { shouldPinGallery } = ui;
return {
shouldPinGallery,
galleryImageMinimumWidth,
@ -95,6 +109,8 @@ const mainSelector = createSelector(
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView,
selectedBoardId: boards.selectedBoardId,
};
},
defaultSelectorOptions
@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
shouldAutoSwitchToNewImages,
shouldUseSingleGalleryColumn,
selectedImage,
galleryView,
} = useAppSelector(mainSelector);
const { images, areMoreImagesAvailable, isLoading, categories } =
useAppSelector(categorySelector);
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
useAppSelector(itemSelector);
const { selectedBoard } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => ({
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
}),
});
const filteredImagesTotal = useMemo(
() => selectedBoard?.image_count ?? allImagesTotal,
[allImagesTotal, selectedBoard?.image_count]
);
const areMoreAvailable = useMemo(() => {
return images.length < filteredImagesTotal;
}, [images.length, filteredImagesTotal]);
const handleLoadMoreImages = useCallback(() => {
dispatch(receivedPageOfImages());
}, [dispatch]);
dispatch(
receivedPageOfImages({
categories,
boardId: selectedBoardId,
})
);
}, [categories, dispatch, selectedBoardId]);
const handleEndReached = useMemo(() => {
if (areMoreImagesAvailable && !isLoading) {
if (areMoreAvailable && !isLoading) {
return handleLoadMoreImages;
}
return undefined;
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
}, [areMoreAvailable, handleLoadMoreImages, isLoading]);
const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
const handleChangeGalleryImageMinimumWidth = (v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
@ -172,46 +211,79 @@ const ImageGalleryContent = () => {
const handleClickImagesCategory = useCallback(() => {
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
dispatch(setGalleryView('images'));
}, [dispatch]);
const handleClickAssetsCategory = useCallback(() => {
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
dispatch(setGalleryView('assets'));
}, [dispatch]);
return (
<Flex
<VStack
sx={{
gap: 2,
flexDirection: 'column',
h: 'full',
w: 'full',
borderRadius: 'base',
}}
>
<Flex
ref={resizeObserverRef}
alignItems="center"
justifyContent="space-between"
>
<ButtonGroup isAttached>
<IAIIconButton
tooltip={t('gallery.images')}
aria-label={t('gallery.images')}
onClick={handleClickImagesCategory}
isChecked={categories === IMAGE_CATEGORIES}
<Box sx={{ w: 'full' }}>
<Flex
ref={resizeObserverRef}
sx={{
alignItems: 'center',
justifyContent: 'space-between',
gap: 2,
}}
>
<ButtonGroup isAttached>
<IAIIconButton
tooltip={t('gallery.images')}
aria-label={t('gallery.images')}
onClick={handleClickImagesCategory}
isChecked={galleryView === 'images'}
size="sm"
icon={<FaImage />}
/>
<IAIIconButton
tooltip={t('gallery.assets')}
aria-label={t('gallery.assets')}
onClick={handleClickAssetsCategory}
isChecked={galleryView === 'assets'}
size="sm"
icon={<FaServer />}
/>
</ButtonGroup>
<Flex
as={Button}
onClick={onToggle}
size="sm"
icon={<FaImage />}
/>
<IAIIconButton
tooltip={t('gallery.assets')}
aria-label={t('gallery.assets')}
onClick={handleClickAssetsCategory}
isChecked={categories === ASSETS_CATEGORIES}
size="sm"
icon={<FaServer />}
/>
</ButtonGroup>
<Flex gap={2}>
variant="ghost"
sx={{
w: 'full',
justifyContent: 'center',
alignItems: 'center',
px: 2,
_hover: {
bg: 'base.800',
},
}}
>
<Text
noOfLines={1}
sx={{ w: 'full', color: 'base.200', fontWeight: 600 }}
>
{selectedBoard ? selectedBoard.board_name : 'All Images'}
</Text>
<ChevronUpIcon
sx={{
transform: isBoardListOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
</Flex>
<IAIPopover
triggerComponent={
<IAIIconButton
@ -269,9 +341,12 @@ const ImageGalleryContent = () => {
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
/>
</Flex>
</Flex>
<Flex direction="column" gap={2} h="full">
{images.length || areMoreImagesAvailable ? (
<Box>
<BoardsList isOpen={isBoardListOpen} />
</Box>
</Box>
<Flex direction="column" gap={2} h="full" w="full">
{images.length || areMoreAvailable ? (
<>
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
{shouldUseSingleGalleryColumn ? (
@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
data={images}
endReached={handleEndReached}
scrollerRef={(ref) => setScrollerRef(ref)}
itemContent={(index, image) => (
itemContent={(index, item) => (
<Flex sx={{ pb: 2 }}>
<HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={
selectedImage?.image_name === image?.image_name
}
key={`${item.image_name}-${item.thumbnail_url}`}
image={item}
isSelected={selectedImage === item?.image_name}
/>
</Flex>
)}
@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
List: ListContainer,
}}
scrollerRef={setScroller}
itemContent={(index, image) => (
itemContent={(index, item) => (
<HoverableImage
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={
selectedImage?.image_name === image?.image_name
}
key={`${item.image_name}-${item.thumbnail_url}`}
image={item}
isSelected={selectedImage === item?.image_name}
/>
)}
/>
@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
</Box>
<IAIButton
onClick={handleLoadMoreImages}
isDisabled={!areMoreImagesAvailable}
isDisabled={!areMoreAvailable}
isLoading={isLoading}
loadingText="Loading"
flexShrink={0}
>
{areMoreImagesAvailable
{areMoreAvailable
? t('gallery.loadMore')
: t('gallery.allImagesLoaded')}
</IAIButton>
@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
</Flex>
)}
</Flex>
</Flex>
</VStack>
);
};

View File

@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
image: ImageDTO;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.image_name === next.image.image_name;
// TODO: Show more interesting information in this component.
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const {
recallBothPrompts,
@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Flex>
</Flex>
);
}, memoEqualityCheck);
};
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
export default ImageMetadataViewer;
export default memo(ImageMetadataViewer);

View File

@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
}
const currentImageIndex = filteredImageIds.findIndex(
(i) => i === selectedImage.image_name
(i) => i === selectedImage
);
const nextImageIndex = clamp(
@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
nextImage,
prevImage,
nextImageId,
prevImageId,
};
},
{
@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { isOnFirstImage, isOnLastImage, nextImage, prevImage } =
const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
useAppSelector(nextPrevImageButtonsSelector);
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
}, []);
const handlePrevImage = useCallback(() => {
dispatch(imageSelected(prevImage));
}, [dispatch, prevImage]);
dispatch(imageSelected(prevImageId));
}, [dispatch, prevImageId]);
const handleNextImage = useCallback(() => {
dispatch(imageSelected(nextImage));
}, [dispatch, nextImage]);
dispatch(imageSelected(nextImageId));
}, [dispatch, nextImageId]);
useHotkeys(
'left',
() => {
handlePrevImage();
},
[prevImage]
[prevImageId]
);
useHotkeys(
@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
() => {
handleNextImage();
},
[nextImage]
[nextImageId]
);
return (

View File

@ -0,0 +1,26 @@
import { motion } from 'framer-motion';
export const SelectedItemOverlay = () => (
<motion.div
initial={{
opacity: 0,
}}
animate={{
opacity: 1,
transition: { duration: 0.1 },
}}
exit={{
opacity: 0,
transition: { duration: 0.1 },
}}
style={{
position: 'absolute',
top: 0,
insetInlineStart: 0,
width: '100%',
height: '100%',
boxShadow: 'inset 0px 0px 0px 2px var(--invokeai-colors-accent-300)',
borderRadius: 'var(--invokeai-radii-base)',
}}
/>
);

View File

@ -0,0 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { selectBoardsAll } from './boardSlice';
export const boardSelector = (state: RootState) => state.boards.entities;
export const searchBoardsSelector = createSelector(
(state: RootState) => state,
(state) => {
const {
boards: { searchText },
} = state;
if (!searchText) {
// If no search text provided, return all entities
return selectBoardsAll(state);
}
return selectBoardsAll(state).filter((i) =>
i.board_name.toLowerCase().includes(searchText.toLowerCase())
);
}
);

View File

@ -0,0 +1,47 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { api } from 'services/apiSlice';
type BoardsState = {
searchText: string;
selectedBoardId?: string;
updateBoardModalOpen: boolean;
};
export const initialBoardsState: BoardsState = {
updateBoardModalOpen: false,
searchText: '',
};
const boardsSlice = createSlice({
name: 'boards',
initialState: initialBoardsState,
reducers: {
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedBoardId = action.payload;
},
setBoardSearchText: (state, action: PayloadAction<string>) => {
state.searchText = action.payload;
},
setUpdateBoardModalOpen: (state, action: PayloadAction<boolean>) => {
state.updateBoardModalOpen = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(
api.endpoints.deleteBoard.matchFulfilled,
(state, action) => {
if (action.meta.arg.originalArgs === state.selectedBoardId) {
state.selectedBoardId = undefined;
}
}
);
},
});
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
boardsSlice.actions;
export const boardsSelector = (state: RootState) => state.boards;
export default boardsSlice.reducer;

View File

@ -1,17 +1,16 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api';
import { imageUpserted } from './imagesSlice';
import { imageUrlsReceived } from 'services/thunks/image';
type GalleryImageObjectFitType = 'contain' | 'cover';
export interface GalleryState {
selectedImage?: ImageDTO;
selectedImage?: string;
galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
shouldUseSingleGalleryColumn: boolean;
galleryView: 'images' | 'assets' | 'boards';
}
export const initialGalleryState: GalleryState = {
@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
galleryImageObjectFit: 'cover',
shouldAutoSwitchToNewImages: true,
shouldUseSingleGalleryColumn: false,
galleryView: 'images',
};
export const gallerySlice = createSlice({
name: 'gallery',
initialState: initialGalleryState,
reducers: {
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
imageSelected: (state, action: PayloadAction<string | undefined>) => {
state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false;
@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
) => {
state.shouldUseSingleGalleryColumn = action.payload;
},
setGalleryView: (
state,
action: PayloadAction<'images' | 'assets' | 'boards'>
) => {
state.galleryView = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(imageUpserted, (state, action) => {
@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
state.shouldAutoSwitchToNewImages &&
action.payload.image_category === 'general'
) {
state.selectedImage = action.payload;
state.selectedImage = action.payload.image_name;
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
// const { image_name, image_url, thumbnail_url } = action.payload;
if (state.selectedImage?.image_name === image_name) {
state.selectedImage.image_url = image_url;
state.selectedImage.thumbnail_url = thumbnail_url;
}
});
// if (state.selectedImage?.image_name === image_name) {
// state.selectedImage.image_url = image_url;
// state.selectedImage.thumbnail_url = thumbnail_url;
// }
// });
},
});
@ -75,6 +81,7 @@ export const {
setGalleryImageObjectFit,
setShouldAutoSwitchToNewImages,
setShouldUseSingleGalleryColumn,
setGalleryView,
} = gallerySlice.actions;
export default gallerySlice.reducer;

View File

@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
import { keyBy } from 'lodash-es';
import {
imageDeleted,
imageMetadataReceived,
imageUrlsReceived,
receivedPageOfImages,
} from 'services/thunks/image';
@ -74,11 +73,21 @@ const imagesSlice = createSlice({
});
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
state.isLoading = false;
const { boardId, categories, imageOrigin, isIntermediate } =
action.meta.arg;
const { items, offset, limit, total } = action.payload;
imagesAdapter.upsertMany(state, items);
if (!categories?.includes('general') || boardId) {
// need to skip updating the total images count if the images recieved were for a specific board
// TODO: this doesn't work when on the Asset tab/category...
return;
}
state.offset = offset;
state.limit = limit;
state.total = total;
imagesAdapter.upsertMany(state, items);
});
builder.addCase(imageDeleted.pending, (state, action) => {
// Image deleted
@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
.map((i) => i.image_name);
}
);
// export const selectImageById = createSelector(
// (state: RootState, imageId) => state,
// (state) => {
// const {
// images: { categories },
// } = state;
// return selectImagesAll(state)
// .filter((i) => categories.includes(i.image_category))
// .map((i) => i.image_name);
// }
// );

View File

@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { Flex } from '@chakra-ui/react';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(field.value ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (field.value?.image_name === droppedImage.image_name) {
if (field.value === droppedImage.image_name) {
return;
}
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
fieldValueChanged({
nodeId,
fieldName: field.name,
value: droppedImage,
value: droppedImage.image_name,
})
);
},
[dispatch, field.name, field.value?.image_name, nodeId]
[dispatch, field.name, field.value, nodeId]
);
const handleReset = useCallback(() => {
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
}}
>
<IAIDndImage
image={field.value}
image={image}
onDrop={handleDrop}
onReset={handleReset}
resetIconSize="sm"

View File

@ -1,28 +1,18 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ModelInputFieldTemplate,
ModelInputFieldValue,
} from 'features/nodes/types/types';
import { selectModelsIds } from 'features/system/store/modelSlice';
import { isEqual } from 'lodash-es';
import { ChangeEvent, memo } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
[selectModelsIds],
(allModelNames) => {
return { allModelNames };
// return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
import { memo, useCallback, useEffect, useMemo } from 'react';
import { FieldComponentProps } from './types';
import { forEach, isString } from 'lodash-es';
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useTranslation } from 'react-i18next';
import { useListModelsQuery } from 'services/apiSlice';
const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { allModelNames } = useAppSelector(availableModelsSelector);
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
const data = useMemo(() => {
if (!pipelineModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: BASE_MODEL_NAME_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
[pipelineModels?.entities, pipelineModels?.ids, field.value]
);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: v,
})
);
},
[dispatch, field.name, nodeId]
);
useEffect(() => {
if (field.value && pipelineModels?.ids.includes(field.value)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleValueChanged(firstModel);
}, [field.value, handleValueChanged, pipelineModels?.ids]);
return (
<Select
<IAIMantineSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model &&
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
}
value={field.value}
placeholder="Pick one"
data={data}
onChange={handleValueChanged}
value={field.value || allModelNames[0]}
>
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
/>
);
};

View File

@ -101,21 +101,6 @@ const nodesSlice = createSlice({
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
state.nodes.forEach((node) => {
forEach(node.data.inputs, (input) => {
if (input.type === 'image') {
if (input.value?.image_name === image_name) {
input.value.image_url = image_url;
input.value.thumbnail_url = thumbnail_url;
}
}
});
});
});
},
});

View File

@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: ImageDTO;
value?: string;
};
export type ModelInputFieldValue = FieldValueBase & {

View File

@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
if (processedControlImage && processorType !== 'none') {
// We've already processed the image in the app, so we can just use the processed image
const { image_name } = processedControlImage;
controlNetNode.image = {
image_name,
image_name: processedControlImage,
};
} else if (controlImage) {
// The control image is preprocessed
const { image_name } = controlImage;
controlNetNode.image = {
image_name,
image_name: controlImage,
};
} else {
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly

View File

@ -23,6 +23,7 @@ import {
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
const {
positivePrompt,
negativePrompt,
model: model_name,
model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
id: NOISE,
},
[MODEL_LOADER]: {
type: 'sd1_model_loader',
type: 'pipeline_model_loader',
id: MODEL_LOADER,
model_name,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',

View File

@ -17,6 +17,7 @@ import {
INPAINT_GRAPH,
INPAINT,
} from './constants';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
const {
positivePrompt,
negativePrompt,
model: model_name,
model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const model = modelIdToPipelineModelField(modelId);
const graph: NonNullableGraph = {
id: INPAINT_GRAPH,
nodes: {
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
prompt: negativePrompt,
},
[MODEL_LOADER]: {
type: 'sd1_model_loader',
type: 'pipeline_model_loader',
id: MODEL_LOADER,
model_name,
model,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',

View File

@ -14,6 +14,7 @@ import {
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/**
* Builds the Canvas tab's Text to Image graph.
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
const {
positivePrompt,
negativePrompt,
model: model_name,
model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const model = modelIdToPipelineModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
steps,
},
[MODEL_LOADER]: {
type: 'sd1_model_loader',
type: 'pipeline_model_loader',
id: MODEL_LOADER,
model_name,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',

View File

@ -22,6 +22,7 @@ import {
} from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
const moduleLog = log.child({ namespace: 'nodes' });
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
const {
positivePrompt,
negativePrompt,
model: model_name,
model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No initial image found in state');
}
const model = modelIdToPipelineModelField(modelId);
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: IMAGE_TO_IMAGE_GRAPH,
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
id: NOISE,
},
[MODEL_LOADER]: {
type: 'sd1_model_loader',
type: 'pipeline_model_loader',
id: MODEL_LOADER,
model_name,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = (
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage.image_name,
image_name: initialImage.imageName,
},
is_intermediate: true,
width,
@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = (
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
image_name: initialImage.image_name,
image_name: initialImage.imageName,
});
// Pass the image's dimensions to the `NOISE` node

View File

@ -1,6 +1,10 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import {
ITERATE,
LATENTS_TO_IMAGE,
@ -14,6 +18,7 @@ import {
TEXT_TO_LATENTS,
} from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
type TextToImageGraphOverrides = {
width: number;
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
const {
positivePrompt,
negativePrompt,
model: model_name,
model: modelId,
cfgScale: cfg_scale,
scheduler,
steps,
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
shouldRandomizeSeed,
} = state.generation;
const model = modelIdToPipelineModelField(modelId);
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
steps,
},
[MODEL_LOADER]: {
type: 'sd1_model_loader',
type: 'pipeline_model_loader',
id: MODEL_LOADER,
model_name,
model,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',

View File

@ -1,9 +1,10 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
import { cloneDeep, omit, reduce } from 'lodash-es';
import { RootState } from 'app/store/store';
import { InputFieldValue } from 'features/nodes/types/types';
import { AnyInvocation } from 'services/events/types';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
/**
* We need to do special handling for some fields
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
}
}
if (field.type === 'model') {
if (field.value) {
return modelIdToPipelineModelField(field.value);
}
}
return field.value;
};

View File

@ -7,7 +7,7 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MODEL_LOADER = 'model_loader';
export const MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';

View File

@ -0,0 +1,18 @@
import { BaseModelType, PipelineModelField } from 'services/api';
/**
* Crudely converts a model id to a pipeline model field
* TODO: Make better
*/
export const modelIdToPipelineModelField = (
modelId: string
): PipelineModelField => {
const [base_model, model_type, model_name] = modelId.split('/');
const field: PipelineModelField = {
base_model: base_model as BaseModelType,
model_name,
};
return field;
};

View File

@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
}
imageToImageNode.image = {
image_name: initialImage.image_name,
image_name: initialImage.imageName,
};
}

View File

@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
const ParamSchedulerAndModel = () => {
return (
<Flex gap={3} w="full">
<Box w="20rem">
<Box w="25rem">
<ParamScheduler />
</Box>
<Box w="full">

View File

@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDndImage from 'common/components/IAIDndImage';
import { ImageDTO } from 'services/api';
import { IAIImageFallback } from 'common/components/IAIImageFallback';
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
import { useGetImageDTOQuery } from 'services/apiSlice';
import { skipToken } from '@reduxjs/toolkit/dist/query';
const selector = createSelector(
[generationSelector],
@ -27,14 +29,21 @@ const InitialImagePreview = () => {
const { initialImage } = useAppSelector(selector);
const dispatch = useAppDispatch();
const {
data: image,
isLoading,
isError,
isSuccess,
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
const handleDrop = useCallback(
(droppedImage: ImageDTO) => {
if (droppedImage.image_name === initialImage?.image_name) {
if (droppedImage.image_name === initialImage?.imageName) {
return;
}
dispatch(initialImageChanged(droppedImage));
},
[dispatch, initialImage?.image_name]
[dispatch, initialImage]
);
const handleReset = useCallback(() => {
@ -53,10 +62,10 @@ const InitialImagePreview = () => {
}}
>
<IAIDndImage
image={initialImage}
image={image}
onDrop={handleDrop}
onReset={handleReset}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
withResetIcon
/>

View File

@ -1,10 +1,9 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
import { configChanged } from 'features/system/store/configSlice';
import { clamp, sortBy } from 'lodash-es';
import { clamp } from 'lodash-es';
import { ImageDTO } from 'services/api';
import { imageUrlsReceived } from 'services/thunks/image';
import { receivedModels } from 'services/thunks/model';
import {
CfgScaleParam,
HeightParam,
@ -17,14 +16,13 @@ import {
StrengthParam,
WidthParam,
} from './parameterZodSchemas';
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
export interface GenerationState {
cfgScale: CfgScaleParam;
height: HeightParam;
img2imgStrength: StrengthParam;
infillMethod: string;
initialImage?: ImageDTO;
initialImage?: { imageName: string; width: number; height: number };
iterations: number;
perlin: number;
positivePrompt: PositivePromptParam;
@ -212,35 +210,20 @@ export const generationSlice = createSlice({
state.shouldUseNoiseSettings = action.payload;
},
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
state.initialImage = action.payload;
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
},
modelSelected: (state, action: PayloadAction<string>) => {
state.model = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(receivedModels.fulfilled, (state, action) => {
if (!state.model) {
const firstModel = sortBy(action.payload, 'name')[0];
state.model = firstModel.name;
}
});
builder.addCase(configChanged, (state, action) => {
const defaultModel = action.payload.sd?.defaultModel;
if (defaultModel && !state.model) {
state.model = defaultModel;
}
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_url, thumbnail_url } = action.payload;
if (state.initialImage?.image_name === image_name) {
state.initialImage.image_url = image_url;
state.initialImage.thumbnail_url = thumbnail_url;
}
});
},
});

View File

@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
*/
export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success;
// /**
// * Zod schema for BaseModelType
// */
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
// /**
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
// */
// export type BaseModelType = z.infer<typeof zBaseModelType>;
// /**
// * Validates/type-guards a value as a base model type
// */
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
// zBaseModelType.safeParse(val).success;

View File

@ -1,44 +1,59 @@
import { createSelector } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';
import { memo, useCallback } from 'react';
import { memo, useCallback, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const selectedModel = selectModelsById(state, generation.model);
import { forEach, isString } from 'lodash-es';
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useListModelsQuery } from 'services/apiSlice';
const modelData = selectModelsAll(state)
.map<IAISelectDataType>((m) => ({
value: m.name,
label: m.name,
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
selectedModel,
modelData,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const MODEL_TYPE_MAP = {
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
};
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { selectedModel, modelData } = useAppSelector(selector);
const selectedModelId = useAppSelector(
(state: RootState) => state.generation.model
);
const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline',
});
const data = useMemo(() => {
if (!pipelineModels) {
return [];
}
const data: SelectItem[] = [];
forEach(pipelineModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [pipelineModels]);
const selectedModel = useMemo(
() => pipelineModels?.entities[selectedModelId],
[pipelineModels?.entities, selectedModelId]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
@ -49,13 +64,27 @@ const ModelSelect = () => {
[dispatch]
);
useEffect(() => {
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
return;
}
const firstModel = pipelineModels?.ids[0];
if (!isString(firstModel)) {
return;
}
handleChangeModel(firstModel);
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
return (
<IAIMantineSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.name ?? ''}
value={selectedModelId}
placeholder="Pick one"
data={modelData}
data={data}
onChange={handleChangeModel}
/>
);

View File

@ -1,6 +1,5 @@
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
export default function SettingsSchedulers() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const enabledSchedulers = useAppSelector(

View File

@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
const isApplicationReadySelector = createSelector(
[systemSelector, configSelector],
(system, config) => {
const { wereModelsReceived, wasSchemaParsed } = system;
const { wasSchemaParsed } = system;
const { disabledTabs } = config;
return {
disabledTabs,
wereModelsReceived,
wasSchemaParsed,
};
}
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
*/
export const useIsApplicationReady = () => {
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
const { disabledTabs, wasSchemaParsed } = useAppSelector(
isApplicationReadySelector
);
const isApplicationReady = useMemo(() => {
if (!wereModelsReceived) {
return false;
}
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
return false;
}
return true;
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
}, [disabledTabs, wasSchemaParsed]);
return isApplicationReady;
};

View File

@ -1,3 +0,0 @@
import { RootState } from 'app/store/store';
export const modelSelector = (state: RootState) => state.models;

View File

@ -1,47 +0,0 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
import { receivedModels } from 'services/thunks/model';
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
name: string;
};
export const modelsAdapter = createEntityAdapter<Model>({
selectId: (model) => model.name,
sortComparer: (a, b) => a.name.localeCompare(b.name),
});
export const initialModelsState = modelsAdapter.getInitialState();
export type ModelsState = typeof initialModelsState;
export const modelsSlice = createSlice({
name: 'models',
initialState: initialModelsState,
reducers: {
modelAdded: modelsAdapter.upsertOne,
},
extraReducers(builder) {
/**
* Received Models - FULFILLED
*/
builder.addCase(receivedModels.fulfilled, (state, action) => {
const models = action.payload;
modelsAdapter.setAll(state, models);
});
},
});
export const {
selectAll: selectModelsAll,
selectById: selectModelsById,
selectEntities: selectModelsEntities,
selectIds: selectModelsIds,
selectTotal: selectModelsTotal,
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
export const { modelAdded } = modelsSlice.actions;
export default modelsSlice.reducer;

View File

@ -1,6 +0,0 @@
import { ModelsState } from './modelSlice';
/**
* Models slice persist denylist
*/
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];

View File

@ -1,20 +1,12 @@
import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import { ProgressImage } from 'services/events/types';
import { makeToast } from '../../../app/components/Toaster';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { receivedModels } from 'services/thunks/model';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { LogLevelName } from 'roarr';
import { InvokeLogLevel } from 'app/logging/useLogger';
import { TFuncKey } from 'i18next';
import { t } from 'i18next';
import { userInvoked } from 'app/store/actions';
import { LANGUAGES } from '../components/LanguagePicker';
import { imageUploaded } from 'services/thunks/image';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr';
import {
appSocketConnected,
appSocketDisconnected,
@ -26,6 +18,11 @@ import {
appSocketSubscribed,
appSocketUnsubscribed,
} from 'services/events/actions';
import { ProgressImage } from 'services/events/types';
import { imageUploaded } from 'services/thunks/image';
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker';
export type CancelStrategy = 'immediate' | 'scheduled';
@ -95,6 +92,7 @@ export interface SystemState {
shouldAntialiasProgressImage: boolean;
language: keyof typeof LANGUAGES;
isUploading: boolean;
boardIdToAddTo?: string;
}
export const initialSystemState: SystemState = {
@ -225,6 +223,7 @@ export const systemSlice = createSlice({
*/
builder.addCase(appSocketSubscribed, (state, action) => {
state.sessionId = action.payload.sessionId;
state.boardIdToAddTo = action.payload.boardId;
state.canceledSession = '';
});
@ -233,6 +232,7 @@ export const systemSlice = createSlice({
*/
builder.addCase(appSocketUnsubscribed, (state) => {
state.sessionId = null;
state.boardIdToAddTo = undefined;
});
/**
@ -376,13 +376,6 @@ export const systemSlice = createSlice({
);
});
/**
* Received available models from the backend
*/
builder.addCase(receivedModels.fulfilled, (state) => {
state.wereModelsReceived = true;
});
/**
* OpenAPI schema was parsed
*/

View File

@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { AddInvocation } from './models/AddInvocation';
export type { BaseModelType } from './models/BaseModelType';
export type { BoardChanges } from './models/BoardChanges';
export type { BoardDTO } from './models/BoardDTO';
export type { Body_create_board_image } from './models/Body_create_board_image';
export type { Body_remove_board_image } from './models/Body_remove_board_image';
export type { Body_upload_image } from './models/Body_upload_image';
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
export type { CkptModelInfo } from './models/CkptModelInfo';
@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
export type { ControlField } from './models/ControlField';
export type { ControlNetInvocation } from './models/ControlNetInvocation';
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
export type { ControlOutput } from './models/ControlOutput';
export type { CreateModelRequest } from './models/CreateModelRequest';
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { LoraInfo } from './models/LoraInfo';
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
export type { LoRAModelConfig } from './models/LoRAModelConfig';
export type { LoRAModelFormat } from './models/LoRAModelFormat';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
export type { PipelineModelField } from './models/PipelineModelField';
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
export type { ShowImageInvocation } from './models/ShowImageInvocation';
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
export type { SubModelType } from './models/SubModelType';
export type { SubtractInvocation } from './models/SubtractInvocation';
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
export type { UNetField } from './models/UNetField';
export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeField } from './models/VaeField';
export type { VaeModelConfig } from './models/VaeModelConfig';
export type { VaeModelFormat } from './models/VaeModelFormat';
export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError';
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
export { BoardsService } from './services/BoardsService';
export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService';
export { SessionsService } from './services/SessionsService';

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type BoardChanges = {
/**
* The board's new name.
*/
board_name?: string;
/**
* The name of the board's new cover image.
*/
cover_image_name?: string;
};

View File

@ -0,0 +1,38 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Deserialized board record with cover image URL and image count.
*/
export type BoardDTO = {
/**
* The unique ID of the board.
*/
board_id: string;
/**
* The name of the board.
*/
board_name: string;
/**
* The created timestamp of the board.
*/
created_at: string;
/**
* The updated timestamp of the board.
*/
updated_at: string;
/**
* The deleted timestamp of the board.
*/
deleted_at?: string;
/**
* The name of the board's cover image.
*/
cover_image_name?: string;
/**
* The number of images in the board.
*/
image_count: number;
};

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_create_board_image = {
/**
* The id of the board to add to
*/
board_id: string;
/**
* The name of the image to add
*/
image_name: string;
};

View File

@ -0,0 +1,15 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Body_remove_board_image = {
/**
* The id of the board
*/
board_id: string;
/**
* The name of the image to remove
*/
image_name: string;
};

View File

@ -0,0 +1,18 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { BaseModelType } from './BaseModelType';
import type { ControlNetModelFormat } from './ControlNetModelFormat';
import type { ModelError } from './ModelError';
export type ControlNetModelConfig = {
name: string;
base_model: BaseModelType;
type: 'controlnet';
path: string;
description?: string;
model_format: ControlNetModelFormat;
error?: ModelError;
};

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* An enumeration.
*/
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';

Some files were not shown because too many files have changed in this diff Show More