mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into install-script-python-version-error-prompt-fix
This commit is contained in:
commit
df1907e849
@ -2,8 +2,17 @@
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.board_images import (
|
||||
BoardImagesService,
|
||||
BoardImagesServiceDependencies,
|
||||
)
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
@ -57,7 +66,7 @@ class ApiDependencies:
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@ -72,14 +81,40 @@ class ApiDependencies:
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
@ -87,6 +122,8 @@ class ApiDependencies:
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
|
69
invokeai/app/api/routers/board_images.py
Normal file
69
invokeai/app/api/routers/board_images.py
Normal file
@ -0,0 +1,69 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
|
||||
@board_images_router.get(
|
||||
"/{board_id}",
|
||||
operation_id="list_board_images",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_board_images(
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of boards per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
|
||||
board_id,
|
||||
)
|
||||
return results
|
||||
|
108
invokeai/app/api/routers/boards.py
Normal file
108
invokeai/app/api/routers/boards.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, Union
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
"/",
|
||||
operation_id="create_board",
|
||||
responses={
|
||||
201: {"description": "The board was created successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
board_name: str = Query(description="The name of the board to create"),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
operation_id="update_board",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The board was updated successfully",
|
||||
},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
) -> None:
|
||||
"""Deletes a board"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@boards_router.get(
|
||||
"/",
|
||||
operation_id="list_boards",
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all()
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
offset,
|
||||
limit,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||
)
|
@ -221,6 +221,9 @@ async def list_images_with_metadata(
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board id to filter by"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
@ -232,6 +235,7 @@ async def list_images_with_metadata(
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import get_all_model_configs
|
||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
base_model: BaseModelType = Query(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: ModelType = Query(
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
|
@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.routers import sessions, models, images, boards, board_images
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
@ -116,6 +120,22 @@ def custom_openapi():
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
|
||||
type: Literal["sd1_model_loader"] = "sd1_model_loader"
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model_name": "model" # TODO: rename to model_name?
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion1 # TODO:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Pipeline
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: optimize(less code copy)
|
||||
class SD2ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model_name": "model" # TODO: rename to model_name?
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
|
254
invokeai/app/services/board_image_record_storage.py
Normal file
254
invokeai/app/services/board_image_record_storage.py
Normal file
@ -0,0 +1,254 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
from invokeai.app.services.board_record_storage import BoardRecord
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_count_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> int:
|
||||
"""Gets the number of images for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `board_images` junction table."""
|
||||
|
||||
# Create the `board_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS board_images (
|
||||
board_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between boards and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON board_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE board_id = ? AND image_name = ?;
|
||||
""",
|
||||
(board_id, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
142
invokeai/app/services/board_images.py
Normal file
142
invokeai/app/services/board_images.py
Normal file
@ -0,0 +1,142 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import List, Union
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardRecord,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
|
||||
"""High-level service for board-image relationship management."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardImagesServiceDependencies:
|
||||
"""Service dependencies for the BoardImagesService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
|
||||
_services: BoardImagesServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardImagesServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
image_records = self._services.board_image_records.get_images_for_board(
|
||||
board_id
|
||||
)
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
board_id,
|
||||
),
|
||||
image_records.items,
|
||||
)
|
||||
)
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=image_records.offset,
|
||||
limit=image_records.limit,
|
||||
total=image_records.total,
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: str | None, image_count: int
|
||||
) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={'cover_image_name'}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
329
invokeai/app/services/board_record_storage.py
Normal file
329
invokeai/app/services/board_record_storage.py
Normal file
@ -0,0 +1,329 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import (
|
||||
BoardRecord,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's new cover image."
|
||||
)
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Gets a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
"""Updates a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `boards` table and `board_images` junction table."""
|
||||
|
||||
# Create the `boards` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
board_id TEXT NOT NULL PRIMARY KEY,
|
||||
board_name TEXT NOT NULL,
|
||||
cover_image_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_id = old.board_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
# Get the total number of boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](
|
||||
items=boards, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
185
invokeai/app/services/boards.py
Normal file
185
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,185 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from logging import Logger
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_images import board_record_to_dto
|
||||
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardChanges,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
|
||||
"""High-level service for board management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Gets a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
"""Updates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> None:
|
||||
"""Deletes a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardServiceDependencies:
|
||||
"""Service dependencies for the BoardService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
|
||||
_services: BoardServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(
|
||||
self, offset: int = 0, limit: int = 10
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](
|
||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||
)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
board_id TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
images_query = """--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
if board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(
|
||||
self, board_id: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
@ -10,6 +10,7 @@ from invokeai.app.models.image import (
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
@ -126,14 +129,16 @@ class ImageServiceDependencies:
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
@ -144,25 +149,8 @@ class ImageServiceDependencies:
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
self._services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
self._services.image_files.save(
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, changes)
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_name)
|
||||
return self._services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_name)
|
||||
return self._services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_name)
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_name, thumbnail)
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.files.validate_path(path)
|
||||
return self._services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
results = self._services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_name)
|
||||
self._services.records.delete(image_name)
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
@ -26,9 +28,9 @@ class InvocationServices:
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
images: "ImageServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
@ -39,7 +41,9 @@ class InvocationServices:
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageService",
|
||||
images: "ImageServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
@ -52,9 +56,12 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.boards = boards
|
||||
self.board_images = board_images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
self.boards = boards
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name and typeof the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type,
|
||||
)
|
||||
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
return self.mgr.default_model()
|
||||
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
self.mgr.set_default_model(model_name, base_model, model_type)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> dict:
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
|
62
invokeai/app/services/models/board_record.py
Normal file
62
invokeai/app/services/models/board_record.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
"""Deserialized board record."""
|
||||
|
||||
board_id: str = Field(description="The unique ID of the board.")
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the cover image of the board."
|
||||
)
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's cover image."
|
||||
)
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
"""Deserializes a board record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
board_id = board_dict.get("board_id", "unknown")
|
||||
board_name = board_dict.get("board_name", "unknown")
|
||||
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
board_name=board_name,
|
||||
cover_image_name=cover_image_name,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
)
|
@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Union[str, None] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -266,6 +266,8 @@ class ModelManager(object):
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
@ -445,38 +447,6 @@ class ModelManager(object):
|
||||
_cache = self.cache,
|
||||
)
|
||||
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
for model_key, model_config in self.models.items():
|
||||
if model_config.default:
|
||||
return self.parse_key(model_key)
|
||||
|
||||
for model_key, _ in self.models.items():
|
||||
return self.parse_key(model_key)
|
||||
else:
|
||||
return None # TODO: or redo as (None, None, None)
|
||||
|
||||
def set_default_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> None:
|
||||
"""
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_manager.commit()
|
||||
"""
|
||||
|
||||
model_key = self.model_key(model_name, base_model, model_type)
|
||||
if model_key not in self.models:
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
for cur_model_key, config in self.models.items():
|
||||
config.default = cur_model_key == model_key
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -503,9 +473,9 @@ class ModelManager(object):
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a dict of models, in format [base_model][model_type][model_name]
|
||||
Return a list of models.
|
||||
|
||||
Please use model_manager.models() to get all the model names,
|
||||
model_manager.model_info('model-name') to get the stanza for the model
|
||||
@ -513,7 +483,7 @@ class ModelManager(object):
|
||||
object derived from models.yaml
|
||||
"""
|
||||
|
||||
models = dict()
|
||||
models = []
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
model_config = self.models[model_key]
|
||||
|
||||
@ -523,18 +493,16 @@ class ModelManager(object):
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
|
||||
if cur_base_model not in models:
|
||||
models[cur_base_model] = dict()
|
||||
if cur_model_type not in models[cur_base_model]:
|
||||
models[cur_base_model][cur_model_type] = dict()
|
||||
|
||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
||||
model_dict = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
# OpenAPIModelInfoBase
|
||||
name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
)
|
||||
|
||||
models.append(model_dict)
|
||||
|
||||
return models
|
||||
|
||||
def print_models(self) -> None:
|
||||
@ -646,7 +614,9 @@ class ModelManager(object):
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
# alias for config file
|
||||
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||
|
||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||
config_file_path = conf_file or self.config_path
|
||||
|
@ -1,3 +1,7 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
@ -29,10 +33,63 @@ MODEL_CLASSES = {
|
||||
#},
|
||||
}
|
||||
|
||||
def get_all_model_configs():
|
||||
configs = set()
|
||||
for models in MODEL_CLASSES.values():
|
||||
for _, model in models.items():
|
||||
configs.update(model._get_configs().values())
|
||||
configs.discard(None)
|
||||
return list(configs) # TODO: set, list or tuple
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
continue
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
#globals()[openapi_cfg_name] = api_wrapper
|
||||
vars()[openapi_cfg_name] = api_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
fields = inspect.get_annotations(model_config)
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception("format field not found")
|
||||
|
||||
# model_format: None
|
||||
# model_format: SomeModelFormat
|
||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
|
||||
|
@ -48,12 +48,10 @@ class ModelError(str, Enum):
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
#name: str # not included as present in model key
|
||||
description: Optional[str] = Field(None)
|
||||
format: Optional[str] = Field(None)
|
||||
default: Optional[bool] = Field(False)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None, exclude=True)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if all(t is None for t in subtypes):
|
||||
return None
|
||||
elif any(t is None for t in subtypes):
|
||||
raise Exception(f"Unsupported definition: {subtypes}")
|
||||
|
||||
if subtypes[0] in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[subtypes[0]]
|
||||
subtypes = subtypes[1:]
|
||||
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
|
||||
continue
|
||||
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
if format_type not in {None, Literal, Union}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
for model_format in field:
|
||||
configs[model_format.value] = value
|
||||
|
||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
for model_format in field.__args__:
|
||||
configs[model_format.value] = value
|
||||
|
||||
elif field is None:
|
||||
configs[None] = value
|
||||
|
||||
if format_type == Union:
|
||||
f_fields = fields["format"].__args__
|
||||
else:
|
||||
f_fields = (fields["format"],)
|
||||
|
||||
|
||||
for field in f_fields:
|
||||
if field is None:
|
||||
format_name = None
|
||||
else:
|
||||
format_name = field.__args__[0]
|
||||
|
||||
configs[format_name] = value # TODO: error when override(multiple)?
|
||||
|
||||
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||
|
||||
cls.__configs = configs
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||
if "format" not in kwargs:
|
||||
raise Exception("Field 'format' not found in model config")
|
||||
if "model_format" not in kwargs:
|
||||
raise Exception("Field 'model_format' not found in model config")
|
||||
|
||||
configs = cls._get_configs()
|
||||
return configs[kwargs["format"]](**kwargs)
|
||||
return configs[kwargs["model_format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=cls.detect_format(path),
|
||||
model_format=cls.detect_format(path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
@ -14,12 +15,16 @@ from .base import (
|
||||
classproperty,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: ControlNetModelFormat
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return ControlNetModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
||||
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
|
||||
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||
else:
|
||||
return model_path
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
@ -12,11 +13,15 @@ from .base import (
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
||||
class LoRAModelFormat(str, Enum):
|
||||
LyCORIS = "lycoris"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return LoRAModelFormat.Diffusers
|
||||
else:
|
||||
return "lycoris"
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == "diffusers":
|
||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||
# TODO: add diffusers lora when it stabilizes a bit
|
||||
raise NotImplementedError("Diffusers lora not supported")
|
||||
else:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
@ -19,16 +20,19 @@ from .base import (
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class StableDiffusion1ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == "checkpoint":
|
||||
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == "diffusers":
|
||||
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return "diffusers"
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == "checkpoint":
|
||||
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == "diffusers":
|
||||
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return "diffusers"
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
|
||||
prediction_type = SchedulerPredictionType.Epsilon
|
||||
|
||||
elif version == BaseModelType.StableDiffusion2:
|
||||
upcast_attention = config.upcast_attention
|
||||
prediction_type = config.prediction_type
|
||||
upcast_attention = model_config.upcast_attention
|
||||
prediction_type = model_config.prediction_type
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown model provided: {version}")
|
||||
|
@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
model_format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: VaeModelFormat
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return VaeModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||
return _convert_vae_ckpt_and_cache(
|
||||
weights_path=model_path,
|
||||
output_path=output_path,
|
||||
|
@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { useListModelsQuery } from 'services/apiSlice';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
@ -45,6 +47,18 @@ const App = ({
|
||||
|
||||
const isApplicationReady = useIsApplicationReady();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
const { data: controlnetModels } = useListModelsQuery({
|
||||
model_type: 'controlnet',
|
||||
});
|
||||
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||
const { data: embeddingModels } = useListModelsQuery({
|
||||
model_type: 'embedding',
|
||||
});
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
@ -143,6 +157,7 @@ const App = ({
|
||||
</Portal>
|
||||
</Grid>
|
||||
<DeleteImageModal />
|
||||
<UpdateImageBoardModal />
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
|
@ -21,6 +21,8 @@ import {
|
||||
DeleteImageContext,
|
||||
DeleteImageContextProvider,
|
||||
} from 'app/contexts/DeleteImageContext';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
@ -76,11 +78,13 @@ const InvokeAIUI = ({
|
||||
<ThemeLocaleProvider>
|
||||
<ImageDndContext>
|
||||
<DeleteImageContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
<AddImageToBoardContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</AddImageToBoardContextProvider>
|
||||
</DeleteImageContextProvider>
|
||||
</ImageDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
|
@ -0,0 +1,89 @@
|
||||
import { useDisclosure } from '@chakra-ui/react';
|
||||
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useAddImageToBoardMutation } from 'services/apiSlice';
|
||||
|
||||
export type ImageUsage = {
|
||||
isInitialImage: boolean;
|
||||
isCanvasImage: boolean;
|
||||
isNodesImage: boolean;
|
||||
isControlNetImage: boolean;
|
||||
};
|
||||
|
||||
type AddImageToBoardContextValue = {
|
||||
/**
|
||||
* Whether the move image dialog is open.
|
||||
*/
|
||||
isOpen: boolean;
|
||||
/**
|
||||
* Closes the move image dialog.
|
||||
*/
|
||||
onClose: () => void;
|
||||
/**
|
||||
* The image pending movement
|
||||
*/
|
||||
image?: ImageDTO;
|
||||
onClickAddToBoard: (image: ImageDTO) => void;
|
||||
handleAddToBoard: (boardId: string) => void;
|
||||
};
|
||||
|
||||
export const AddImageToBoardContext =
|
||||
createContext<AddImageToBoardContextValue>({
|
||||
isOpen: false,
|
||||
onClose: () => undefined,
|
||||
onClickAddToBoard: () => undefined,
|
||||
handleAddToBoard: () => undefined,
|
||||
});
|
||||
|
||||
type Props = PropsWithChildren;
|
||||
|
||||
export const AddImageToBoardContextProvider = (props: Props) => {
|
||||
const [imageToMove, setImageToMove] = useState<ImageDTO>();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const [addImageToBoard, result] = useAddImageToBoardMutation();
|
||||
|
||||
// Clean up after deleting or dismissing the modal
|
||||
const closeAndClearImageToDelete = useCallback(() => {
|
||||
setImageToMove(undefined);
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
|
||||
const onClickAddToBoard = useCallback(
|
||||
(image?: ImageDTO) => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
setImageToMove(image);
|
||||
onOpen();
|
||||
},
|
||||
[setImageToMove, onOpen]
|
||||
);
|
||||
|
||||
const handleAddToBoard = useCallback(
|
||||
(boardId: string) => {
|
||||
if (imageToMove) {
|
||||
addImageToBoard({
|
||||
board_id: boardId,
|
||||
image_name: imageToMove.image_name,
|
||||
});
|
||||
closeAndClearImageToDelete();
|
||||
}
|
||||
},
|
||||
[addImageToBoard, closeAndClearImageToDelete, imageToMove]
|
||||
);
|
||||
|
||||
return (
|
||||
<AddImageToBoardContext.Provider
|
||||
value={{
|
||||
isOpen,
|
||||
image: imageToMove,
|
||||
onClose: closeAndClearImageToDelete,
|
||||
onClickAddToBoard,
|
||||
handleAddToBoard,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</AddImageToBoardContext.Provider>
|
||||
);
|
||||
};
|
@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
|
||||
(state: RootState, image_name?: string) => image_name,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, image_name) => {
|
||||
const isInitialImage = generation.initialImage?.image_name === image_name;
|
||||
const isInitialImage = generation.initialImage?.imageName === image_name;
|
||||
|
||||
const isCanvasImage = canvas.layerState.objects.some(
|
||||
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
|
||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||
);
|
||||
|
||||
const isNodesImage = nodes.nodes.some((node) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) =>
|
||||
input.type === 'image' && input.value?.image_name === image_name
|
||||
(input) => input.type === 'image' && input.value === image_name
|
||||
);
|
||||
});
|
||||
|
||||
const isControlNetImage = some(
|
||||
controlNet.controlNets,
|
||||
(c) =>
|
||||
c.controlImage?.image_name === image_name ||
|
||||
c.processedControlImage?.image_name === image_name
|
||||
c.controlImage === image_name || c.processedControlImage === image_name
|
||||
);
|
||||
|
||||
const imageUsage: ImageUsage = {
|
||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { omit } from 'lodash-es';
|
||||
@ -18,7 +17,6 @@ const serializationDenylist: {
|
||||
gallery: galleryPersistDenylist,
|
||||
generation: generationPersistDenylist,
|
||||
lightbox: lightboxPersistDenylist,
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
|
@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
@ -21,7 +20,6 @@ const initialStates: {
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
system: initialSystemState,
|
||||
|
@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||
import {
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
} from './listeners/imageAddedToBoard';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
* Most side effect logic should live in a listener.
|
||||
*/
|
||||
|
||||
// Image uploaded
|
||||
addImageUploadedFulfilledListener();
|
||||
addImageUploadedRejectedListener();
|
||||
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
|
||||
|
||||
// Update image URLs on connect
|
||||
addUpdateImageUrlsOnConnectListener();
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener();
|
||||
addImageAddedToBoardRejectedListener();
|
||||
addImageRemovedFromBoardFulfilledListener();
|
||||
addImageRemovedFromBoardRejectedListener();
|
||||
addBoardIdSelectedListener();
|
||||
|
@ -0,0 +1,99 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
|
||||
import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addBoardIdSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: boardIdSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const boardId = action.payload;
|
||||
|
||||
// we need to check if we need to fetch more images
|
||||
|
||||
const state = getState();
|
||||
const allImages = selectImagesAll(state);
|
||||
|
||||
if (!boardId) {
|
||||
// a board was unselected
|
||||
dispatch(imageSelected(allImages[0]?.image_name));
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
|
||||
const filteredImages = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
// get the board from the cache
|
||||
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||
const board = boards?.find((b) => b.board_id === boardId);
|
||||
|
||||
if (!board) {
|
||||
// can't find the board in cache...
|
||||
dispatch(imageSelected(allImages[0]?.image_name));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageSelected(board.cover_image_name));
|
||||
|
||||
// if we haven't loaded one full page of images from this board, load more
|
||||
if (
|
||||
filteredImages.length < board.image_count &&
|
||||
filteredImages.length < IMAGES_PER_PAGE
|
||||
) {
|
||||
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addBoardIdSelected_changeSelectedImage_listener = () => {
|
||||
startAppListening({
|
||||
actionCreator: boardIdSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const boardId = action.payload;
|
||||
|
||||
const state = getState();
|
||||
|
||||
// we need to check if we need to fetch more images
|
||||
|
||||
if (!boardId) {
|
||||
// a board was unselected - we don't need to do anything
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
|
||||
const filteredImages = selectImagesAll(state).filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
// get the board from the cache
|
||||
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||
const board = boards?.find((b) => b.board_id === boardId);
|
||||
if (!board) {
|
||||
// can't find the board in cache...
|
||||
return;
|
||||
}
|
||||
|
||||
// if we haven't loaded one full page of images from this board, load more
|
||||
if (
|
||||
filteredImages.length < board.image_count &&
|
||||
filteredImages.length < IMAGES_PER_PAGE
|
||||
) {
|
||||
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
[controlNet.processorNode.id]: {
|
||||
...controlNet.processorNode,
|
||||
is_intermediate: true,
|
||||
image: pick(controlNet.controlImage, ['image_name']),
|
||||
image: { image_name: controlNet.controlImage },
|
||||
},
|
||||
},
|
||||
};
|
||||
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId,
|
||||
processedControlImage,
|
||||
processedControlImage: processedControlImage.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addImageAddedToBoardFulfilledListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.addImageToBoard.matchFulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Image added to board'
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageName: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageAddedToBoardRejectedListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.addImageToBoard.matchRejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Problem adding image to board'
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(
|
||||
getState()
|
||||
).length;
|
||||
const state = getState();
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(receivedPageOfImages());
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: action.payload,
|
||||
boardId: state.boards.selectedBoardId,
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
selectImagesEntities,
|
||||
selectImagesIds,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
/**
|
||||
* Called when the user requests an image deletion
|
||||
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch, getState, condition }) => {
|
||||
const { image, imageUsage } = action.payload;
|
||||
|
||||
const { image_name } = image;
|
||||
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage;
|
||||
|
||||
if (selectedImage && selectedImage.image_name === image_name) {
|
||||
if (selectedImage === image_name) {
|
||||
const ids = selectImagesIds(state);
|
||||
const entities = selectImagesEntities(state);
|
||||
|
||||
const deletedImageIndex = ids.findIndex(
|
||||
(result) => result.toString() === image_name
|
||||
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = entities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
dispatch(imageSelected(newSelectedImageId as string));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
}
|
||||
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
|
||||
dispatch(imageRemoved(image_name));
|
||||
|
||||
// Delete from server
|
||||
dispatch(imageDeleted({ imageName: image_name }));
|
||||
const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
|
||||
|
||||
// Wait for successful deletion, then trigger boards to re-fetch
|
||||
const wasImageDeleted = await condition(
|
||||
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
|
||||
imageDeleted.fulfilled.match(action) &&
|
||||
action.meta.requestId === requestId,
|
||||
30000
|
||||
);
|
||||
|
||||
if (wasImageDeleted) {
|
||||
dispatch(
|
||||
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addImageRemovedFromBoardFulfilledListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Image added to board'
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageName: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageRemovedFromBoardRejectedListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.removeImageFromBoard.matchRejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Problem adding image to board'
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
|
||||
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
||||
const { controlNetId } = postUploadAction;
|
||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: image.image_name,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { models, nodes, config, images } = getState();
|
||||
const { nodes, config, images } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!images.ids.length) {
|
||||
dispatch(receivedPageOfImages());
|
||||
}
|
||||
|
||||
if (!models.ids.length) {
|
||||
dispatch(receivedModels());
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
isIntermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
|
@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
|
||||
|
||||
const sessionId = action.payload.data.graph_execution_state_id;
|
||||
|
||||
const { cancelType, isCancelScheduled } = getState().system;
|
||||
const { cancelType, isCancelScheduled, boardIdToAddTo } =
|
||||
getState().system;
|
||||
|
||||
// Handle scheduled cancelation
|
||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||
@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
if (boardIdToAddTo && !imageDTO.is_intermediate) {
|
||||
dispatch(
|
||||
api.endpoints.addImageToBoard.initiate({
|
||||
board_id: boardIdToAddTo,
|
||||
image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(progressImageSet(null));
|
||||
}
|
||||
// pass along the socket event as an application action
|
||||
|
@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
|
||||
selectImagesEntities,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, imageEntities) => {
|
||||
const allUsedImages: ImageDTO[] = [];
|
||||
const allUsedImages: string[] = [];
|
||||
|
||||
if (generation.initialImage) {
|
||||
allUsedImages.push(generation.initialImage);
|
||||
allUsedImages.push(generation.initialImage.imageName);
|
||||
}
|
||||
|
||||
canvas.layerState.objects.forEach((obj) => {
|
||||
if (obj.kind === 'image') {
|
||||
allUsedImages.push(obj.image);
|
||||
allUsedImages.push(obj.imageName);
|
||||
}
|
||||
});
|
||||
|
||||
@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
|
||||
|
||||
forEach(imageEntities, (image) => {
|
||||
if (image) {
|
||||
allUsedImages.push(image);
|
||||
allUsedImages.push(image.image_name);
|
||||
}
|
||||
});
|
||||
|
||||
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
|
||||
`Fetching new image URLs for ${allUsedImages.length} images`
|
||||
);
|
||||
|
||||
allUsedImages.forEach(({ image_name }) => {
|
||||
allUsedImages.forEach((image_name) => {
|
||||
dispatch(
|
||||
imageUrlsReceived({
|
||||
imageName: image_name,
|
||||
|
@ -5,40 +5,39 @@ import {
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
system: systemReducer,
|
||||
@ -47,7 +46,9 @@ const allReducers = {
|
||||
hotkeys: hotkeysReducer,
|
||||
images: imagesReducer,
|
||||
controlNet: controlNetReducer,
|
||||
boards: boardsReducer,
|
||||
// session: sessionReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'gallery',
|
||||
'generation',
|
||||
'lightbox',
|
||||
// 'models',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
'ui',
|
||||
'controlNet',
|
||||
// 'boards',
|
||||
// 'hotkeys',
|
||||
// 'config',
|
||||
];
|
||||
@ -84,6 +85,7 @@ export const store = configureStore({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
devTools: {
|
||||
|
@ -9,7 +9,7 @@ import {
|
||||
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||
import { useCombinedRefs } from '@dnd-kit/utilities';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||
@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
isDropDisabled = false,
|
||||
isDragDisabled = false,
|
||||
isUploadDisabled = false,
|
||||
fallback = <IAIImageFallback />,
|
||||
fallback = <IAIImageLoadingFallback />,
|
||||
payloadImage,
|
||||
minSize = 24,
|
||||
postUploadAction,
|
||||
|
@ -1,10 +1,20 @@
|
||||
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
import {
|
||||
As,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Icon,
|
||||
IconProps,
|
||||
Spinner,
|
||||
SpinnerProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
type Props = FlexProps & {
|
||||
spinnerProps?: SpinnerProps;
|
||||
};
|
||||
|
||||
export const IAIImageFallback = (props: Props) => {
|
||||
export const IAIImageLoadingFallback = (props: Props) => {
|
||||
const { spinnerProps, ...rest } = props;
|
||||
const { sx, ...restFlexProps } = rest;
|
||||
return (
|
||||
@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type IAINoImageFallbackProps = {
|
||||
flexProps?: FlexProps;
|
||||
iconProps?: IconProps;
|
||||
as?: As;
|
||||
};
|
||||
|
||||
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
|
||||
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
|
||||
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
bg: 'base.900',
|
||||
opacity: 0.7,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
...flexSx,
|
||||
}}
|
||||
{...restFlexProps}
|
||||
>
|
||||
<Icon
|
||||
as={props.as ?? FaImage}
|
||||
sx={{ color: 'base.700', ...iconSx }}
|
||||
{...restIconProps}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -1,14 +1,21 @@
|
||||
import { Image } from 'react-konva';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { Image, Rect } from 'react-konva';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import useImage from 'use-image';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
|
||||
type IAICanvasImageProps = {
|
||||
url: string;
|
||||
x: number;
|
||||
y: number;
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||
const { url, x, y } = props;
|
||||
const [image] = useImage(url, 'anonymous');
|
||||
const { width, height, x, y, imageName } = props.canvasImage;
|
||||
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
|
||||
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
|
||||
|
||||
if (!imageDTO) {
|
||||
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
||||
}
|
||||
|
||||
return <Image x={x} y={y} image={image} listening={false} />;
|
||||
};
|
||||
|
||||
|
@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
|
||||
<Group name="outpainting-objects" listening={false}>
|
||||
{objects.map((obj, i) => {
|
||||
if (isCanvasBaseImage(obj)) {
|
||||
return (
|
||||
<IAICanvasImage
|
||||
key={i}
|
||||
x={obj.x}
|
||||
y={obj.y}
|
||||
url={obj.image.image_url}
|
||||
/>
|
||||
);
|
||||
return <IAICanvasImage key={i} canvasImage={obj} />;
|
||||
} else if (isCanvasBaseLine(obj)) {
|
||||
const line = (
|
||||
<Line
|
||||
|
@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
return (
|
||||
<Group {...rest}>
|
||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||
<IAICanvasImage
|
||||
url={currentStagingAreaImage.image.image_url}
|
||||
x={x}
|
||||
y={y}
|
||||
/>
|
||||
<IAICanvasImage canvasImage={currentStagingAreaImage} />
|
||||
)}
|
||||
{shouldShowStagingOutline && (
|
||||
<Group>
|
||||
|
@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
|
||||
y: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
image: image,
|
||||
imageName: image.image_name,
|
||||
},
|
||||
],
|
||||
};
|
||||
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
...state.layerState.stagingArea.boundingBox,
|
||||
image,
|
||||
imageName: image.image_name,
|
||||
});
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex =
|
||||
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
state.layerState.objects.forEach((object) => {
|
||||
if (object.kind === 'image') {
|
||||
if (object.image.image_name === image_name) {
|
||||
object.image.image_url = image_url;
|
||||
object.image.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
// state.layerState.objects.forEach((object) => {
|
||||
// if (object.kind === 'image') {
|
||||
// if (object.image.image_name === image_name) {
|
||||
// object.image.image_url = image_url;
|
||||
// object.image.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||
if (stagedImage.image.image_name === image_name) {
|
||||
stagedImage.image.image_url = image_url;
|
||||
stagedImage.image.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
});
|
||||
// state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||
// if (stagedImage.image.image_name === image_name) {
|
||||
// stagedImage.image.image_url = image_url;
|
||||
// stagedImage.image.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -38,7 +38,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: ImageDTO;
|
||||
imageName: string;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
|
@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const selector = createSelector(
|
||||
controlNetSelector,
|
||||
@ -31,24 +33,45 @@ type Props = {
|
||||
|
||||
const ControlNetImagePreview = (props: Props) => {
|
||||
const { imageSx } = props;
|
||||
const { controlNetId, controlImage, processedControlImage, processorType } =
|
||||
props.controlNet;
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
} = props.controlNet;
|
||||
const dispatch = useAppDispatch();
|
||||
const { pendingControlImages } = useAppSelector(selector);
|
||||
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
const {
|
||||
data: controlImage,
|
||||
isLoading: isLoadingControlImage,
|
||||
isError: isErrorControlImage,
|
||||
isSuccess: isSuccessControlImage,
|
||||
} = useGetImageDTOQuery(controlImageName ?? skipToken);
|
||||
|
||||
const {
|
||||
data: processedControlImage,
|
||||
isLoading: isLoadingProcessedControlImage,
|
||||
isError: isErrorProcessedControlImage,
|
||||
isSuccess: isSuccessProcessedControlImage,
|
||||
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (controlImage?.image_name === droppedImage.image_name) {
|
||||
if (controlImageName === droppedImage.image_name) {
|
||||
return;
|
||||
}
|
||||
setIsMouseOverImage(false);
|
||||
dispatch(
|
||||
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: droppedImage.image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlImage, controlNetId, dispatch]
|
||||
[controlImageName, controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<IAIImageFallback />
|
||||
<IAIImageLoadingFallback />
|
||||
</Box>
|
||||
)}
|
||||
{controlImage && (
|
||||
|
@ -39,8 +39,8 @@ export type ControlNetConfig = {
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
controlImage: ImageDTO | null;
|
||||
processedControlImage: ImageDTO | null;
|
||||
controlImage: string | null;
|
||||
processedControlImage: string | null;
|
||||
processorType: ControlNetProcessorType;
|
||||
processorNode: RequiredControlNetProcessorNode;
|
||||
shouldAutoConfig: boolean;
|
||||
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
|
||||
},
|
||||
controlNetAddedFromImage: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||
) => {
|
||||
const { controlNetId, controlImage } = action.payload;
|
||||
state.controlNets[controlNetId] = {
|
||||
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
controlImage: ImageDTO | null;
|
||||
controlImage: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, controlImage } = action.payload;
|
||||
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
processedControlImage: ImageDTO | null;
|
||||
processedControlImage: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, processedControlImage } = action.payload;
|
||||
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
|
||||
// Preemptively remove the image from the gallery
|
||||
const { imageName } = action.meta.arg;
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage?.image_name === imageName) {
|
||||
if (c.controlImage === imageName) {
|
||||
c.controlImage = null;
|
||||
c.processedControlImage = null;
|
||||
}
|
||||
if (c.processedControlImage?.image_name === imageName) {
|
||||
if (c.processedControlImage === imageName) {
|
||||
c.processedControlImage = null;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage?.image_name === image_name) {
|
||||
c.controlImage.image_url = image_url;
|
||||
c.controlImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
if (c.processedControlImage?.image_name === image_name) {
|
||||
c.processedControlImage.image_url = image_url;
|
||||
c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
});
|
||||
// forEach(state.controlNets, (c) => {
|
||||
// if (c.controlImage?.image_name === image_name) {
|
||||
// c.controlImage.image_url = image_url;
|
||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// if (c.processedControlImage?.image_name === image_name) {
|
||||
// c.processedControlImage.image_url = image_url;
|
||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
|
||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||
state.pendingControlImages = [];
|
||||
|
@ -0,0 +1,27 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { useCallback } from 'react';
|
||||
import { useCreateBoardMutation } from 'services/apiSlice';
|
||||
|
||||
const DEFAULT_BOARD_NAME = 'My Board';
|
||||
|
||||
const AddBoardButton = () => {
|
||||
const [createBoard, { isLoading }] = useCreateBoardMutation();
|
||||
|
||||
const handleCreateBoard = useCallback(() => {
|
||||
createBoard(DEFAULT_BOARD_NAME);
|
||||
}, [createBoard]);
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
isLoading={isLoading}
|
||||
aria-label="Add Board"
|
||||
onClick={handleCreateBoard}
|
||||
size="sm"
|
||||
sx={{ px: 4 }}
|
||||
>
|
||||
Add Board
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default AddBoardButton;
|
@ -0,0 +1,93 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { FaImages } from 'react-icons/fa';
|
||||
import { boardIdSelected } from '../../store/boardSlice';
|
||||
import { useDispatch } from 'react-redux';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
import { useCallback } from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
|
||||
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const handleAllImagesBoardClick = () => {
|
||||
dispatch(boardIdSelected());
|
||||
};
|
||||
|
||||
const [removeImageFromBoard, { isLoading }] =
|
||||
useRemoveImageFromBoardMutation();
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (!droppedImage.board_id) {
|
||||
return;
|
||||
}
|
||||
removeImageFromBoard({
|
||||
board_id: droppedImage.board_id,
|
||||
image_name: droppedImage.image_name,
|
||||
});
|
||||
},
|
||||
[removeImageFromBoard]
|
||||
);
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
id: `board_droppable_all_images`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
onClick={handleAllImagesBoardClick}
|
||||
>
|
||||
<Flex
|
||||
ref={setNodeRef}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
<Text
|
||||
sx={{
|
||||
color: isSelected ? 'base.50' : 'base.200',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
}}
|
||||
>
|
||||
All Images
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default AllImagesBoard;
|
@ -0,0 +1,134 @@
|
||||
import {
|
||||
Collapse,
|
||||
Flex,
|
||||
Grid,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
boardsSelector,
|
||||
setBoardSearchText,
|
||||
} from 'features/gallery/store/boardSlice';
|
||||
import { memo, useState } from 'react';
|
||||
import HoverableBoard from './HoverableBoard';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import AddBoardButton from './AddBoardButton';
|
||||
import AllImagesBoard from './AllImagesBoard';
|
||||
import { CloseIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[boardsSelector],
|
||||
(boardsState) => {
|
||||
const { selectedBoardId, searchText } = boardsState;
|
||||
return { selectedBoardId, searchText };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type Props = {
|
||||
isOpen: boolean;
|
||||
};
|
||||
|
||||
const BoardsList = (props: Props) => {
|
||||
const { isOpen } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { selectedBoardId, searchText } = useAppSelector(selector);
|
||||
|
||||
const { data: boards } = useListAllBoardsQuery();
|
||||
|
||||
const filteredBoards = searchText
|
||||
? boards?.filter((board) =>
|
||||
board.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||
)
|
||||
: boards;
|
||||
|
||||
const [searchMode, setSearchMode] = useState(false);
|
||||
|
||||
const handleBoardSearch = (searchTerm: string) => {
|
||||
setSearchMode(searchTerm.length > 0);
|
||||
dispatch(setBoardSearchText(searchTerm));
|
||||
};
|
||||
const clearBoardSearch = () => {
|
||||
setSearchMode(false);
|
||||
dispatch(setBoardSearchText(''));
|
||||
};
|
||||
|
||||
return (
|
||||
<Collapse in={isOpen} animateOpacity>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
p: 2,
|
||||
mt: 2,
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
placeholder="Search Boards..."
|
||||
value={searchText}
|
||||
onChange={(e) => {
|
||||
handleBoardSearch(e.target.value);
|
||||
}}
|
||||
/>
|
||||
{searchText && searchText.length && (
|
||||
<InputRightElement>
|
||||
<IconButton
|
||||
onClick={clearBoardSearch}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
aria-label="Clear Search"
|
||||
icon={<CloseIcon boxSize={3} />}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
<AddBoardButton />
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent
|
||||
defer
|
||||
style={{ height: '100%', width: '100%' }}
|
||||
options={{
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'move',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Grid
|
||||
className="list-container"
|
||||
sx={{
|
||||
gap: 2,
|
||||
gridTemplateRows: '5.5rem 5.5rem',
|
||||
gridAutoFlow: 'column dense',
|
||||
gridAutoColumns: '4rem',
|
||||
}}
|
||||
>
|
||||
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
|
||||
{filteredBoards &&
|
||||
filteredBoards.map((board) => (
|
||||
<HoverableBoard
|
||||
key={board.board_id}
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Collapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BoardsList);
|
@ -0,0 +1,193 @@
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
Editable,
|
||||
EditableInput,
|
||||
EditablePreview,
|
||||
Flex,
|
||||
Image,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import { BoardDTO, ImageDTO } from 'services/api';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import {
|
||||
useAddImageToBoardMutation,
|
||||
useDeleteBoardMutation,
|
||||
useGetImageDTOQuery,
|
||||
useUpdateBoardMutation,
|
||||
} from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
|
||||
interface HoverableBoardProps {
|
||||
board: BoardDTO;
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data: coverImage } = useGetImageDTOQuery(
|
||||
board.cover_image_name ?? skipToken
|
||||
);
|
||||
|
||||
const { board_name, board_id } = board;
|
||||
|
||||
const handleSelectBoard = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
}, [board_id, dispatch]);
|
||||
|
||||
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||
useUpdateBoardMutation();
|
||||
|
||||
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
|
||||
useDeleteBoardMutation();
|
||||
|
||||
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
|
||||
useAddImageToBoardMutation();
|
||||
|
||||
const handleUpdateBoardName = (newBoardName: string) => {
|
||||
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
||||
};
|
||||
|
||||
const handleDeleteBoard = useCallback(() => {
|
||||
deleteBoard(board_id);
|
||||
}, [board_id, deleteBoard]);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (droppedImage.board_id === board_id) {
|
||||
return;
|
||||
}
|
||||
addImageToBoard({ board_id, image_name: droppedImage.image_name });
|
||||
},
|
||||
[addImageToBoard, board_id]
|
||||
);
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
id: `board_droppable_${board_id}`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Box sx={{ touchAction: 'none' }}>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||
<MenuItem
|
||||
sx={{ color: 'error.300' }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDeleteBoard}
|
||||
>
|
||||
Delete Board
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{(ref) => (
|
||||
<Flex
|
||||
key={board_id}
|
||||
userSelect="none"
|
||||
ref={ref}
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
ref={setNodeRef}
|
||||
onClick={handleSelectBoard}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
>
|
||||
{board.cover_image_name && coverImage?.image_url && (
|
||||
<Image src={coverImage?.image_url} draggable={false} />
|
||||
)}
|
||||
{!(board.cover_image_name && coverImage?.image_url) && (
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
|
||||
)}
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
insetInlineEnd: 0,
|
||||
top: 0,
|
||||
p: 1,
|
||||
}}
|
||||
>
|
||||
<Badge variant="solid">{board.image_count}</Badge>
|
||||
</Flex>
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
|
||||
<Box sx={{ width: 'full' }}>
|
||||
<Editable
|
||||
defaultValue={board_name}
|
||||
submitOnBlur={false}
|
||||
onSubmit={(nextValue) => {
|
||||
handleUpdateBoardName(nextValue);
|
||||
}}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
color: isSelected ? 'base.50' : 'base.200',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
textAlign: 'center',
|
||||
p: 0,
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
sx={{
|
||||
color: 'base.50',
|
||||
fontSize: 'xs',
|
||||
borderColor: 'base.500',
|
||||
p: 0,
|
||||
outline: 0,
|
||||
}}
|
||||
/>
|
||||
</Editable>
|
||||
</Box>
|
||||
</Flex>
|
||||
)}
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
HoverableBoard.displayName = 'HoverableBoard';
|
||||
|
||||
export default HoverableBoard;
|
@ -0,0 +1,93 @@
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogContent,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogOverlay,
|
||||
Box,
|
||||
Flex,
|
||||
Spinner,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { memo, useContext, useRef, useState } from 'react';
|
||||
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const UpdateImageBoardModal = () => {
|
||||
// const boards = useSelector(selectBoardsAll);
|
||||
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||
const { isOpen, onClose, handleAddToBoard, image } = useContext(
|
||||
AddImageToBoardContext
|
||||
);
|
||||
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const currentBoard = boards?.find(
|
||||
(board) => board.board_id === image?.board_id
|
||||
);
|
||||
|
||||
return (
|
||||
<AlertDialog
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Box>
|
||||
<Flex direction="column" gap={3}>
|
||||
{currentBoard && (
|
||||
<Text>
|
||||
Moving this image from{' '}
|
||||
<strong>{currentBoard.board_name}</strong> to
|
||||
</Text>
|
||||
)}
|
||||
{isFetching ? (
|
||||
<Spinner />
|
||||
) : (
|
||||
<IAIMantineSelect
|
||||
placeholder="Select Board"
|
||||
onChange={(v) => setSelectedBoard(v)}
|
||||
value={selectedBoard}
|
||||
data={(boards ?? []).map((board) => ({
|
||||
label: board.board_name,
|
||||
value: board.board_id,
|
||||
}))}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</AlertDialogBody>
|
||||
<AlertDialogFooter>
|
||||
<IAIButton onClick={onClose}>Cancel</IAIButton>
|
||||
<IAIButton
|
||||
isDisabled={!selectedBoard}
|
||||
colorScheme="accent"
|
||||
onClick={() => {
|
||||
if (selectedBoard) {
|
||||
handleAddToBoard(selectedBoard);
|
||||
}
|
||||
}}
|
||||
ml={3}
|
||||
>
|
||||
{currentBoard ? 'Move' : 'Add'}
|
||||
</IAIButton>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(UpdateImageBoardModal);
|
@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { DeleteImageButton } from './DeleteImageModal';
|
||||
import { selectImagesById } from '../store/imagesSlice';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state,
|
||||
systemSelector,
|
||||
gallerySelector,
|
||||
postprocessingSelector,
|
||||
@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
lightboxSelector,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||
const {
|
||||
isProcessing,
|
||||
isConnected,
|
||||
@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
|
||||
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
|
||||
|
||||
const { selectedImage } = gallery;
|
||||
|
||||
return {
|
||||
@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
shouldHidePreview,
|
||||
image: selectedImage,
|
||||
seed: selectedImage?.metadata?.seed,
|
||||
prompt: selectedImage?.metadata?.positive_conditioning,
|
||||
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
||||
image: imageDTO,
|
||||
seed: imageDTO?.metadata?.seed,
|
||||
prompt: imageDTO?.metadata?.positive_conditioning,
|
||||
negativePrompt: imageDTO?.metadata?.negative_conditioning,
|
||||
shouldShowProgressInViewer,
|
||||
};
|
||||
},
|
||||
|
@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { configSelector } from '../../system/store/configSelectors';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { imageSelected } from '../store/gallerySlice';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, gallerySelector, systemSelector],
|
||||
@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
|
||||
return {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
image: selectedImage,
|
||||
selectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
|
||||
const CurrentImagePreview = () => {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
image,
|
||||
selectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
// const image = useAppSelector((state: RootState) =>
|
||||
// selectImagesById(state, selectedImage ?? '')
|
||||
// );
|
||||
|
||||
const {
|
||||
data: image,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(selectedImage ?? skipToken);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleDrop = useCallback(
|
||||
@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
|
||||
if (droppedImage.image_name === image?.image_name) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(droppedImage));
|
||||
dispatch(imageSelected(droppedImage.image_name));
|
||||
},
|
||||
[dispatch, image?.image_name]
|
||||
);
|
||||
@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={image}
|
||||
image={selectedImage && image ? image : undefined}
|
||||
onDrop={handleDrop}
|
||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
||||
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||
isUploadDisabled={true}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{shouldShowImageDetails && image && (
|
||||
{shouldShowImageDetails && image && selectedImage && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -119,7 +131,7 @@ const CurrentImagePreview = () => {
|
||||
<ImageMetadataViewer image={image} />
|
||||
</Box>
|
||||
)}
|
||||
{!shouldShowImageDetails && image && (
|
||||
{!shouldShowImageDetails && image && selectedImage && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
|
@ -2,7 +2,14 @@ import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useContext, useState } from 'react';
|
||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import {
|
||||
FaCheck,
|
||||
FaExpand,
|
||||
FaFolder,
|
||||
FaImage,
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
@ -27,6 +34,8 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useDraggable } from '@dnd-kit/core';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||
|
||||
export const selector = createSelector(
|
||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||
@ -62,17 +71,10 @@ interface HoverableImageProps {
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) =>
|
||||
prev.image.image_name === next.image.image_name &&
|
||||
prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
*/
|
||||
const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const HoverableImage = (props: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
activeTabName,
|
||||
@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onDelete } = useContext(DeleteImageContext);
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
const handleDelete = useCallback(() => {
|
||||
onDelete(image);
|
||||
}, [image, onDelete]);
|
||||
@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
},
|
||||
});
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleSelectImage = useCallback(() => {
|
||||
dispatch(imageSelected(image));
|
||||
dispatch(imageSelected(image.image_name));
|
||||
}, [image, dispatch]);
|
||||
|
||||
// Recall parameters handlers
|
||||
@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleRemoveFromBoard}
|
||||
>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.300' }}
|
||||
icon={<FaTrash />}
|
||||
@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
};
|
||||
|
||||
HoverableImage.displayName = 'HoverableImage';
|
||||
|
||||
export default HoverableImage;
|
||||
export default memo(HoverableImage);
|
||||
|
@ -1,12 +1,15 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Grid,
|
||||
Icon,
|
||||
Text,
|
||||
VStack,
|
||||
forwardRef,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@ -20,6 +23,7 @@ import {
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setGalleryView,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
@ -53,41 +57,51 @@ import {
|
||||
selectImagesAll,
|
||||
} from '../store/imagesSlice';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import BoardsList from './Boards/BoardsList';
|
||||
import { boardsSelector } from '../store/boardSlice';
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const categorySelector = createSelector(
|
||||
const itemSelector = createSelector(
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const { images } = state;
|
||||
const { categories } = images;
|
||||
const { categories, total: allImagesTotal, isLoading } = state.images;
|
||||
const { selectedBoardId } = state.boards;
|
||||
|
||||
const allImages = selectImagesAll(state);
|
||||
const filteredImages = allImages.filter((i) =>
|
||||
categories.includes(i.image_category)
|
||||
);
|
||||
|
||||
const images = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = selectedBoardId
|
||||
? i.board_id === selectedBoardId
|
||||
: true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
return {
|
||||
images: filteredImages,
|
||||
isLoading: images.isLoading,
|
||||
areMoreImagesAvailable: filteredImages.length < images.total,
|
||||
categories: images.categories,
|
||||
images,
|
||||
allImagesTotal,
|
||||
isLoading,
|
||||
categories,
|
||||
selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const mainSelector = createSelector(
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery, ui) => {
|
||||
[gallerySelector, uiSelector, boardsSelector],
|
||||
(gallery, ui, boards) => {
|
||||
const {
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
} = gallery;
|
||||
|
||||
const { shouldPinGallery } = ui;
|
||||
|
||||
return {
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
@ -95,6 +109,8 @@ const mainSelector = createSelector(
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
selectedBoardId: boards.selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
} = useAppSelector(mainSelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading, categories } =
|
||||
useAppSelector(categorySelector);
|
||||
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
|
||||
useAppSelector(itemSelector);
|
||||
|
||||
const { selectedBoard } = useListAllBoardsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
|
||||
}),
|
||||
});
|
||||
|
||||
const filteredImagesTotal = useMemo(
|
||||
() => selectedBoard?.image_count ?? allImagesTotal,
|
||||
[allImagesTotal, selectedBoard?.image_count]
|
||||
);
|
||||
|
||||
const areMoreAvailable = useMemo(() => {
|
||||
return images.length < filteredImagesTotal;
|
||||
}, [images.length, filteredImagesTotal]);
|
||||
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(receivedPageOfImages());
|
||||
}, [dispatch]);
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories,
|
||||
boardId: selectedBoardId,
|
||||
})
|
||||
);
|
||||
}, [categories, dispatch, selectedBoardId]);
|
||||
|
||||
const handleEndReached = useMemo(() => {
|
||||
if (areMoreImagesAvailable && !isLoading) {
|
||||
if (areMoreAvailable && !isLoading) {
|
||||
return handleLoadMoreImages;
|
||||
}
|
||||
return undefined;
|
||||
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
|
||||
}, [areMoreAvailable, handleLoadMoreImages, isLoading]);
|
||||
|
||||
const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
dispatch(setGalleryImageMinimumWidth(v));
|
||||
@ -172,46 +211,79 @@ const ImageGalleryContent = () => {
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
dispatch(setGalleryView('images'));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickAssetsCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
||||
dispatch(setGalleryView('assets'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
<VStack
|
||||
sx={{
|
||||
gap: 2,
|
||||
flexDirection: 'column',
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
ref={resizeObserverRef}
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.images')}
|
||||
aria-label={t('gallery.images')}
|
||||
onClick={handleClickImagesCategory}
|
||||
isChecked={categories === IMAGE_CATEGORIES}
|
||||
<Box sx={{ w: 'full' }}>
|
||||
<Flex
|
||||
ref={resizeObserverRef}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.images')}
|
||||
aria-label={t('gallery.images')}
|
||||
onClick={handleClickImagesCategory}
|
||||
isChecked={galleryView === 'images'}
|
||||
size="sm"
|
||||
icon={<FaImage />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.assets')}
|
||||
aria-label={t('gallery.assets')}
|
||||
onClick={handleClickAssetsCategory}
|
||||
isChecked={galleryView === 'assets'}
|
||||
size="sm"
|
||||
icon={<FaServer />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Flex
|
||||
as={Button}
|
||||
onClick={onToggle}
|
||||
size="sm"
|
||||
icon={<FaImage />}
|
||||
/>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.assets')}
|
||||
aria-label={t('gallery.assets')}
|
||||
onClick={handleClickAssetsCategory}
|
||||
isChecked={categories === ASSETS_CATEGORIES}
|
||||
size="sm"
|
||||
icon={<FaServer />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Flex gap={2}>
|
||||
variant="ghost"
|
||||
sx={{
|
||||
w: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
px: 2,
|
||||
_hover: {
|
||||
bg: 'base.800',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
noOfLines={1}
|
||||
sx={{ w: 'full', color: 'base.200', fontWeight: 600 }}
|
||||
>
|
||||
{selectedBoard ? selectedBoard.board_name : 'All Images'}
|
||||
</Text>
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
transform: isBoardListOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
@ -269,9 +341,12 @@ const ImageGalleryContent = () => {
|
||||
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex direction="column" gap={2} h="full">
|
||||
{images.length || areMoreImagesAvailable ? (
|
||||
<Box>
|
||||
<BoardsList isOpen={isBoardListOpen} />
|
||||
</Box>
|
||||
</Box>
|
||||
<Flex direction="column" gap={2} h="full" w="full">
|
||||
{images.length || areMoreAvailable ? (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
||||
{shouldUseSingleGalleryColumn ? (
|
||||
@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, image) => (
|
||||
itemContent={(index, item) => (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, image) => (
|
||||
itemContent={(index, item) => (
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
|
||||
</Box>
|
||||
<IAIButton
|
||||
onClick={handleLoadMoreImages}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isDisabled={!areMoreAvailable}
|
||||
isLoading={isLoading}
|
||||
loadingText="Loading"
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
{areMoreAvailable
|
||||
? t('gallery.loadMore')
|
||||
: t('gallery.allImagesLoaded')}
|
||||
</IAIButton>
|
||||
@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
// TODO: I don't know if this is needed.
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.image_name === next.image.image_name;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
/**
|
||||
* Image metadata viewer overlays currently selected image and provides
|
||||
* access to any of its metadata for use in processing.
|
||||
*/
|
||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
recallBothPrompts,
|
||||
@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
};
|
||||
|
||||
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
||||
|
||||
export default ImageMetadataViewer;
|
||||
export default memo(ImageMetadataViewer);
|
||||
|
@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
||||
}
|
||||
|
||||
const currentImageIndex = filteredImageIds.findIndex(
|
||||
(i) => i === selectedImage.image_name
|
||||
(i) => i === selectedImage
|
||||
);
|
||||
|
||||
const nextImageIndex = clamp(
|
||||
@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
|
||||
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
||||
nextImage,
|
||||
prevImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { isOnFirstImage, isOnLastImage, nextImage, prevImage } =
|
||||
const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
|
||||
useAppSelector(nextPrevImageButtonsSelector);
|
||||
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
|
||||
@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
|
||||
}, []);
|
||||
|
||||
const handlePrevImage = useCallback(() => {
|
||||
dispatch(imageSelected(prevImage));
|
||||
}, [dispatch, prevImage]);
|
||||
dispatch(imageSelected(prevImageId));
|
||||
}, [dispatch, prevImageId]);
|
||||
|
||||
const handleNextImage = useCallback(() => {
|
||||
dispatch(imageSelected(nextImage));
|
||||
}, [dispatch, nextImage]);
|
||||
dispatch(imageSelected(nextImageId));
|
||||
}, [dispatch, nextImageId]);
|
||||
|
||||
useHotkeys(
|
||||
'left',
|
||||
() => {
|
||||
handlePrevImage();
|
||||
},
|
||||
[prevImage]
|
||||
[prevImageId]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
|
||||
() => {
|
||||
handleNextImage();
|
||||
},
|
||||
[nextImage]
|
||||
[nextImageId]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,26 @@
|
||||
import { motion } from 'framer-motion';
|
||||
|
||||
export const SelectedItemOverlay = () => (
|
||||
<motion.div
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineStart: 0,
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
boxShadow: 'inset 0px 0px 0px 2px var(--invokeai-colors-accent-300)',
|
||||
borderRadius: 'var(--invokeai-radii-base)',
|
||||
}}
|
||||
/>
|
||||
);
|
@ -0,0 +1,23 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { selectBoardsAll } from './boardSlice';
|
||||
|
||||
export const boardSelector = (state: RootState) => state.boards.entities;
|
||||
|
||||
export const searchBoardsSelector = createSelector(
|
||||
(state: RootState) => state,
|
||||
(state) => {
|
||||
const {
|
||||
boards: { searchText },
|
||||
} = state;
|
||||
|
||||
if (!searchText) {
|
||||
// If no search text provided, return all entities
|
||||
return selectBoardsAll(state);
|
||||
}
|
||||
|
||||
return selectBoardsAll(state).filter((i) =>
|
||||
i.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||
);
|
||||
}
|
||||
);
|
@ -0,0 +1,47 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
type BoardsState = {
|
||||
searchText: string;
|
||||
selectedBoardId?: string;
|
||||
updateBoardModalOpen: boolean;
|
||||
};
|
||||
|
||||
export const initialBoardsState: BoardsState = {
|
||||
updateBoardModalOpen: false,
|
||||
searchText: '',
|
||||
};
|
||||
|
||||
const boardsSlice = createSlice({
|
||||
name: 'boards',
|
||||
initialState: initialBoardsState,
|
||||
reducers: {
|
||||
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.selectedBoardId = action.payload;
|
||||
},
|
||||
setBoardSearchText: (state, action: PayloadAction<string>) => {
|
||||
state.searchText = action.payload;
|
||||
},
|
||||
setUpdateBoardModalOpen: (state, action: PayloadAction<boolean>) => {
|
||||
state.updateBoardModalOpen = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(
|
||||
api.endpoints.deleteBoard.matchFulfilled,
|
||||
(state, action) => {
|
||||
if (action.meta.arg.originalArgs === state.selectedBoardId) {
|
||||
state.selectedBoardId = undefined;
|
||||
}
|
||||
}
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
|
||||
boardsSlice.actions;
|
||||
|
||||
export const boardsSelector = (state: RootState) => state.boards;
|
||||
|
||||
export default boardsSlice.reducer;
|
@ -1,17 +1,16 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUpserted } from './imagesSlice';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
export interface GalleryState {
|
||||
selectedImage?: ImageDTO;
|
||||
selectedImage?: string;
|
||||
galleryImageMinimumWidth: number;
|
||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||
shouldAutoSwitchToNewImages: boolean;
|
||||
shouldUseSingleGalleryColumn: boolean;
|
||||
galleryView: 'images' | 'assets' | 'boards';
|
||||
}
|
||||
|
||||
export const initialGalleryState: GalleryState = {
|
||||
@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
|
||||
galleryImageObjectFit: 'cover',
|
||||
shouldAutoSwitchToNewImages: true,
|
||||
shouldUseSingleGalleryColumn: false,
|
||||
galleryView: 'images',
|
||||
};
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState: initialGalleryState,
|
||||
reducers: {
|
||||
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
|
||||
imageSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.selectedImage = action.payload;
|
||||
// TODO: if the user selects an image, disable the auto switch?
|
||||
// state.shouldAutoSwitchToNewImages = false;
|
||||
@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
|
||||
) => {
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
setGalleryView: (
|
||||
state,
|
||||
action: PayloadAction<'images' | 'assets' | 'boards'>
|
||||
) => {
|
||||
state.galleryView = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(imageUpserted, (state, action) => {
|
||||
@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
|
||||
state.shouldAutoSwitchToNewImages &&
|
||||
action.payload.image_category === 'general'
|
||||
) {
|
||||
state.selectedImage = action.payload;
|
||||
state.selectedImage = action.payload.image_name;
|
||||
}
|
||||
});
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
if (state.selectedImage?.image_name === image_name) {
|
||||
state.selectedImage.image_url = image_url;
|
||||
state.selectedImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
// if (state.selectedImage?.image_name === image_name) {
|
||||
// state.selectedImage.image_url = image_url;
|
||||
// state.selectedImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
},
|
||||
});
|
||||
|
||||
@ -75,6 +81,7 @@ export const {
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setGalleryView,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
export default gallerySlice.reducer;
|
||||
|
@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import {
|
||||
imageDeleted,
|
||||
imageMetadataReceived,
|
||||
imageUrlsReceived,
|
||||
receivedPageOfImages,
|
||||
} from 'services/thunks/image';
|
||||
@ -74,11 +73,21 @@ const imagesSlice = createSlice({
|
||||
});
|
||||
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
||||
state.isLoading = false;
|
||||
const { boardId, categories, imageOrigin, isIntermediate } =
|
||||
action.meta.arg;
|
||||
|
||||
const { items, offset, limit, total } = action.payload;
|
||||
imagesAdapter.upsertMany(state, items);
|
||||
|
||||
if (!categories?.includes('general') || boardId) {
|
||||
// need to skip updating the total images count if the images recieved were for a specific board
|
||||
// TODO: this doesn't work when on the Asset tab/category...
|
||||
return;
|
||||
}
|
||||
|
||||
state.offset = offset;
|
||||
state.limit = limit;
|
||||
state.total = total;
|
||||
imagesAdapter.upsertMany(state, items);
|
||||
});
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
// Image deleted
|
||||
@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
|
||||
.map((i) => i.image_name);
|
||||
}
|
||||
);
|
||||
|
||||
// export const selectImageById = createSelector(
|
||||
// (state: RootState, imageId) => state,
|
||||
// (state) => {
|
||||
// const {
|
||||
// images: { categories },
|
||||
// } = state;
|
||||
|
||||
// return selectImagesAll(state)
|
||||
// .filter((i) => categories.includes(i.image_category))
|
||||
// .map((i) => i.image_name);
|
||||
// }
|
||||
// );
|
||||
|
@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const {
|
||||
data: image,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(field.value ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (field.value?.image_name === droppedImage.image_name) {
|
||||
if (field.value === droppedImage.image_name) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: droppedImage,
|
||||
value: droppedImage.image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, field.value?.image_name, nodeId]
|
||||
[dispatch, field.name, field.value, nodeId]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={field.value}
|
||||
image={image}
|
||||
onDrop={handleDrop}
|
||||
onReset={handleReset}
|
||||
resetIconSize="sm"
|
||||
|
@ -1,28 +1,18 @@
|
||||
import { Select } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ModelInputFieldTemplate,
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { selectModelsIds } from 'features/system/store/modelSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const availableModelsSelector = createSelector(
|
||||
[selectModelsIds],
|
||||
(allModelNames) => {
|
||||
return { allModelNames };
|
||||
// return map(modelList, (_, name) => name);
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/apiSlice';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.value,
|
||||
})
|
||||
);
|
||||
};
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstModel);
|
||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||
|
||||
return (
|
||||
<Select
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
value={field.value || allModelNames[0]}
|
||||
>
|
||||
{allModelNames.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -101,21 +101,6 @@ const nodesSlice = createSlice({
|
||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||
state.schema = action.payload;
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
state.nodes.forEach((node) => {
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (input.type === 'image') {
|
||||
if (input.value?.image_name === image_name) {
|
||||
input.value.image_url = image_url;
|
||||
input.value.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
|
||||
|
||||
export type ImageInputFieldValue = FieldValueBase & {
|
||||
type: 'image';
|
||||
value?: ImageDTO;
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ModelInputFieldValue = FieldValueBase & {
|
||||
|
@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
const { image_name } = processedControlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
const { image_name } = controlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
|
@ -23,6 +23,7 @@ import {
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
model: modelId,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
|
||||
id: NOISE,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
type: 'pipeline_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
|
@ -17,6 +17,7 @@ import {
|
||||
INPAINT_GRAPH,
|
||||
INPAINT,
|
||||
} from './constants';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
model: modelId,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
|
||||
// We may need to set the inpaint width and height to scale the image
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
|
||||
const graph: NonNullableGraph = {
|
||||
id: INPAINT_GRAPH,
|
||||
nodes: {
|
||||
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
||||
prompt: negativePrompt,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
type: 'pipeline_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
model,
|
||||
},
|
||||
[RANGE_OF_SIZE]: {
|
||||
type: 'range_of_size',
|
||||
|
@ -14,6 +14,7 @@ import {
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
model: modelId,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
|
||||
steps,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
type: 'pipeline_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
|
@ -22,6 +22,7 @@ import {
|
||||
} from './constants';
|
||||
import { set } from 'lodash-es';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
model: modelId,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
|
||||
id: NOISE,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
type: 'pipeline_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = (
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.image_name,
|
||||
image_name: initialImage.imageName,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = (
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||
image_name: initialImage.image_name,
|
||||
image_name: initialImage.imageName,
|
||||
});
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
|
@ -1,6 +1,10 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
||||
import {
|
||||
BaseModelType,
|
||||
RandomIntInvocation,
|
||||
RangeOfSizeInvocation,
|
||||
} from 'services/api';
|
||||
import {
|
||||
ITERATE,
|
||||
LATENTS_TO_IMAGE,
|
||||
@ -14,6 +18,7 @@ import {
|
||||
TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
type TextToImageGraphOverrides = {
|
||||
width: number;
|
||||
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model: model_name,
|
||||
model: modelId,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
|
||||
shouldRandomizeSeed,
|
||||
} = state.generation;
|
||||
|
||||
const model = modelIdToPipelineModelField(modelId);
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
|
||||
steps,
|
||||
},
|
||||
[MODEL_LOADER]: {
|
||||
type: 'sd1_model_loader',
|
||||
type: 'pipeline_model_loader',
|
||||
id: MODEL_LOADER,
|
||||
model_name,
|
||||
model,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
|
@ -1,9 +1,10 @@
|
||||
import { Graph } from 'services/api';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
|
||||
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { InputFieldValue } from 'features/nodes/types/types';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||
|
||||
/**
|
||||
* We need to do special handling for some fields
|
||||
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
||||
}
|
||||
}
|
||||
|
||||
if (field.type === 'model') {
|
||||
if (field.value) {
|
||||
return modelIdToPipelineModelField(field.value);
|
||||
}
|
||||
}
|
||||
|
||||
return field.value;
|
||||
};
|
||||
|
||||
|
@ -7,7 +7,7 @@ export const NOISE = 'noise';
|
||||
export const RANDOM_INT = 'rand_int';
|
||||
export const RANGE_OF_SIZE = 'range_of_size';
|
||||
export const ITERATE = 'iterate';
|
||||
export const MODEL_LOADER = 'model_loader';
|
||||
export const MODEL_LOADER = 'pipeline_model_loader';
|
||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||
export const RESIZE = 'resize_image';
|
||||
|
@ -0,0 +1,18 @@
|
||||
import { BaseModelType, PipelineModelField } from 'services/api';
|
||||
|
||||
/**
|
||||
* Crudely converts a model id to a pipeline model field
|
||||
* TODO: Make better
|
||||
*/
|
||||
export const modelIdToPipelineModelField = (
|
||||
modelId: string
|
||||
): PipelineModelField => {
|
||||
const [base_model, model_type, model_name] = modelId.split('/');
|
||||
|
||||
const field: PipelineModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
|
||||
}
|
||||
|
||||
imageToImageNode.image = {
|
||||
image_name: initialImage.image_name,
|
||||
image_name: initialImage.imageName,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
|
||||
const ParamSchedulerAndModel = () => {
|
||||
return (
|
||||
<Flex gap={3} w="full">
|
||||
<Box w="20rem">
|
||||
<Box w="25rem">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
<Box w="full">
|
||||
|
@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const selector = createSelector(
|
||||
[generationSelector],
|
||||
@ -27,14 +29,21 @@ const InitialImagePreview = () => {
|
||||
const { initialImage } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const {
|
||||
data: image,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (droppedImage.image_name === initialImage?.image_name) {
|
||||
if (droppedImage.image_name === initialImage?.imageName) {
|
||||
return;
|
||||
}
|
||||
dispatch(initialImageChanged(droppedImage));
|
||||
},
|
||||
[dispatch, initialImage?.image_name]
|
||||
[dispatch, initialImage]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
@ -53,10 +62,10 @@ const InitialImagePreview = () => {
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={initialImage}
|
||||
image={image}
|
||||
onDrop={handleDrop}
|
||||
onReset={handleReset}
|
||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
||||
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
|
||||
withResetIcon
|
||||
/>
|
||||
|
@ -1,10 +1,9 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||
import { configChanged } from 'features/system/store/configSlice';
|
||||
import { clamp, sortBy } from 'lodash-es';
|
||||
import { clamp } from 'lodash-es';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import {
|
||||
CfgScaleParam,
|
||||
HeightParam,
|
||||
@ -17,14 +16,13 @@ import {
|
||||
StrengthParam,
|
||||
WidthParam,
|
||||
} from './parameterZodSchemas';
|
||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||
|
||||
export interface GenerationState {
|
||||
cfgScale: CfgScaleParam;
|
||||
height: HeightParam;
|
||||
img2imgStrength: StrengthParam;
|
||||
infillMethod: string;
|
||||
initialImage?: ImageDTO;
|
||||
initialImage?: { imageName: string; width: number; height: number };
|
||||
iterations: number;
|
||||
perlin: number;
|
||||
positivePrompt: PositivePromptParam;
|
||||
@ -212,35 +210,20 @@ export const generationSlice = createSlice({
|
||||
state.shouldUseNoiseSettings = action.payload;
|
||||
},
|
||||
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||
state.initialImage = action.payload;
|
||||
const { image_name, width, height } = action.payload;
|
||||
state.initialImage = { imageName: image_name, width, height };
|
||||
},
|
||||
modelSelected: (state, action: PayloadAction<string>) => {
|
||||
state.model = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
if (!state.model) {
|
||||
const firstModel = sortBy(action.payload, 'name')[0];
|
||||
state.model = firstModel.name;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(configChanged, (state, action) => {
|
||||
const defaultModel = action.payload.sd?.defaultModel;
|
||||
if (defaultModel && !state.model) {
|
||||
state.model = defaultModel;
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
if (state.initialImage?.image_name === image_name) {
|
||||
state.initialImage.image_url = image_url;
|
||||
state.initialImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
||||
*/
|
||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||
zStrength.safeParse(val).success;
|
||||
|
||||
// /**
|
||||
// * Zod schema for BaseModelType
|
||||
// */
|
||||
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
|
||||
// /**
|
||||
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
|
||||
// */
|
||||
// export type BaseModelType = z.infer<typeof zBaseModelType>;
|
||||
// /**
|
||||
// * Validates/type-guards a value as a base model type
|
||||
// */
|
||||
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
|
||||
// zBaseModelType.safeParse(val).success;
|
||||
|
@ -1,44 +1,59 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect, {
|
||||
IAISelectDataType,
|
||||
} from 'common/components/IAIMantineSelect';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const selectedModel = selectModelsById(state, generation.model);
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useListModelsQuery } from 'services/apiSlice';
|
||||
|
||||
const modelData = selectModelsAll(state)
|
||||
.map<IAISelectDataType>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
return {
|
||||
selectedModel,
|
||||
modelData,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
};
|
||||
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { selectedModel, modelData } = useAppSelector(selector);
|
||||
|
||||
const selectedModelId = useAppSelector(
|
||||
(state: RootState) => state.generation.model
|
||||
);
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[selectedModelId],
|
||||
[pipelineModels?.entities, selectedModelId]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
@ -49,13 +64,27 @@ const ModelSelect = () => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleChangeModel(firstModel);
|
||||
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
value={selectedModel?.name ?? ''}
|
||||
value={selectedModelId}
|
||||
placeholder="Pick one"
|
||||
data={modelData}
|
||||
data={data}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
);
|
||||
|
@ -1,6 +1,5 @@
|
||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
|
||||
|
||||
export default function SettingsSchedulers() {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const enabledSchedulers = useAppSelector(
|
||||
|
@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
|
||||
const isApplicationReadySelector = createSelector(
|
||||
[systemSelector, configSelector],
|
||||
(system, config) => {
|
||||
const { wereModelsReceived, wasSchemaParsed } = system;
|
||||
const { wasSchemaParsed } = system;
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
return {
|
||||
disabledTabs,
|
||||
wereModelsReceived,
|
||||
wasSchemaParsed,
|
||||
};
|
||||
}
|
||||
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
|
||||
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
||||
*/
|
||||
export const useIsApplicationReady = () => {
|
||||
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
|
||||
const { disabledTabs, wasSchemaParsed } = useAppSelector(
|
||||
isApplicationReadySelector
|
||||
);
|
||||
|
||||
const isApplicationReady = useMemo(() => {
|
||||
if (!wereModelsReceived) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
|
||||
}, [disabledTabs, wasSchemaParsed]);
|
||||
|
||||
return isApplicationReady;
|
||||
};
|
||||
|
@ -1,3 +0,0 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
export const modelSelector = (state: RootState) => state.models;
|
@ -1,47 +0,0 @@
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
|
||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
||||
name: string;
|
||||
};
|
||||
|
||||
export const modelsAdapter = createEntityAdapter<Model>({
|
||||
selectId: (model) => model.name,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
|
||||
export const initialModelsState = modelsAdapter.getInitialState();
|
||||
|
||||
export type ModelsState = typeof initialModelsState;
|
||||
|
||||
export const modelsSlice = createSlice({
|
||||
name: 'models',
|
||||
initialState: initialModelsState,
|
||||
reducers: {
|
||||
modelAdded: modelsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Received Models - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
||||
const models = action.payload;
|
||||
modelsAdapter.setAll(state, models);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectModelsAll,
|
||||
selectById: selectModelsById,
|
||||
selectEntities: selectModelsEntities,
|
||||
selectIds: selectModelsIds,
|
||||
selectTotal: selectModelsTotal,
|
||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
||||
|
||||
export const { modelAdded } = modelsSlice.actions;
|
||||
|
||||
export default modelsSlice.reducer;
|
@ -1,6 +0,0 @@
|
||||
import { ModelsState } from './modelSlice';
|
||||
|
||||
/**
|
||||
* Models slice persist denylist
|
||||
*/
|
||||
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
|
@ -1,20 +1,12 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/types/invokeai';
|
||||
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||
import { TFuncKey } from 'i18next';
|
||||
import { t } from 'i18next';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||
import { TFuncKey, t } from 'i18next';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import {
|
||||
appSocketConnected,
|
||||
appSocketDisconnected,
|
||||
@ -26,6 +18,11 @@ import {
|
||||
appSocketSubscribed,
|
||||
appSocketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||
import { makeToast } from '../../../app/components/Toaster';
|
||||
import { LANGUAGES } from '../components/LanguagePicker';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@ -95,6 +92,7 @@ export interface SystemState {
|
||||
shouldAntialiasProgressImage: boolean;
|
||||
language: keyof typeof LANGUAGES;
|
||||
isUploading: boolean;
|
||||
boardIdToAddTo?: string;
|
||||
}
|
||||
|
||||
export const initialSystemState: SystemState = {
|
||||
@ -225,6 +223,7 @@ export const systemSlice = createSlice({
|
||||
*/
|
||||
builder.addCase(appSocketSubscribed, (state, action) => {
|
||||
state.sessionId = action.payload.sessionId;
|
||||
state.boardIdToAddTo = action.payload.boardId;
|
||||
state.canceledSession = '';
|
||||
});
|
||||
|
||||
@ -233,6 +232,7 @@ export const systemSlice = createSlice({
|
||||
*/
|
||||
builder.addCase(appSocketUnsubscribed, (state) => {
|
||||
state.sessionId = null;
|
||||
state.boardIdToAddTo = undefined;
|
||||
});
|
||||
|
||||
/**
|
||||
@ -376,13 +376,6 @@ export const systemSlice = createSlice({
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Received available models from the backend
|
||||
*/
|
||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
||||
state.wereModelsReceived = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* OpenAPI schema was parsed
|
||||
*/
|
||||
|
@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI';
|
||||
|
||||
export type { AddInvocation } from './models/AddInvocation';
|
||||
export type { BaseModelType } from './models/BaseModelType';
|
||||
export type { BoardChanges } from './models/BoardChanges';
|
||||
export type { BoardDTO } from './models/BoardDTO';
|
||||
export type { Body_create_board_image } from './models/Body_create_board_image';
|
||||
export type { Body_remove_board_image } from './models/Body_remove_board_image';
|
||||
export type { Body_upload_image } from './models/Body_upload_image';
|
||||
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
||||
export type { CkptModelInfo } from './models/CkptModelInfo';
|
||||
@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
|
||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||
export type { ControlField } from './models/ControlField';
|
||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
|
||||
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
|
||||
export type { ControlOutput } from './models/ControlOutput';
|
||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||
@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||
export type { IntOutput } from './models/IntOutput';
|
||||
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
||||
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
||||
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
||||
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
|
||||
export type { IterateInvocation } from './models/IterateInvocation';
|
||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||
export type { LatentsField } from './models/LatentsField';
|
||||
@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
||||
export type { LoraInfo } from './models/LoraInfo';
|
||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||
export type { LoRAModelConfig } from './models/LoRAModelConfig';
|
||||
export type { LoRAModelFormat } from './models/LoRAModelFormat';
|
||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||
export type { MaskOutput } from './models/MaskOutput';
|
||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||
@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||
export type { NoiseOutput } from './models/NoiseOutput';
|
||||
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
|
||||
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
|
||||
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
|
||||
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
|
||||
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
||||
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
||||
export type { PipelineModelField } from './models/PipelineModelField';
|
||||
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
|
||||
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
||||
export type { PromptOutput } from './models/PromptOutput';
|
||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||
@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
|
||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
||||
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
||||
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
|
||||
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
|
||||
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
|
||||
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
|
||||
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
|
||||
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
|
||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||
export type { SubModelType } from './models/SubModelType';
|
||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
|
||||
export type { UNetField } from './models/UNetField';
|
||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||
export type { VaeField } from './models/VaeField';
|
||||
export type { VaeModelConfig } from './models/VaeModelConfig';
|
||||
export type { VaeModelFormat } from './models/VaeModelFormat';
|
||||
export type { VaeRepo } from './models/VaeRepo';
|
||||
export type { ValidationError } from './models/ValidationError';
|
||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||
|
||||
export { BoardsService } from './services/BoardsService';
|
||||
export { ImagesService } from './services/ImagesService';
|
||||
export { ModelsService } from './services/ModelsService';
|
||||
export { SessionsService } from './services/SessionsService';
|
||||
|
@ -0,0 +1,15 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type BoardChanges = {
|
||||
/**
|
||||
* The board's new name.
|
||||
*/
|
||||
board_name?: string;
|
||||
/**
|
||||
* The name of the board's new cover image.
|
||||
*/
|
||||
cover_image_name?: string;
|
||||
};
|
||||
|
38
invokeai/frontend/web/src/services/api/models/BoardDTO.ts
Normal file
38
invokeai/frontend/web/src/services/api/models/BoardDTO.ts
Normal file
@ -0,0 +1,38 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* Deserialized board record with cover image URL and image count.
|
||||
*/
|
||||
export type BoardDTO = {
|
||||
/**
|
||||
* The unique ID of the board.
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the board.
|
||||
*/
|
||||
board_name: string;
|
||||
/**
|
||||
* The created timestamp of the board.
|
||||
*/
|
||||
created_at: string;
|
||||
/**
|
||||
* The updated timestamp of the board.
|
||||
*/
|
||||
updated_at: string;
|
||||
/**
|
||||
* The deleted timestamp of the board.
|
||||
*/
|
||||
deleted_at?: string;
|
||||
/**
|
||||
* The name of the board's cover image.
|
||||
*/
|
||||
cover_image_name?: string;
|
||||
/**
|
||||
* The number of images in the board.
|
||||
*/
|
||||
image_count: number;
|
||||
};
|
||||
|
@ -0,0 +1,15 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type Body_create_board_image = {
|
||||
/**
|
||||
* The id of the board to add to
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the image to add
|
||||
*/
|
||||
image_name: string;
|
||||
};
|
||||
|
@ -0,0 +1,15 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
export type Body_remove_board_image = {
|
||||
/**
|
||||
* The id of the board
|
||||
*/
|
||||
board_id: string;
|
||||
/**
|
||||
* The name of the image to remove
|
||||
*/
|
||||
image_name: string;
|
||||
};
|
||||
|
@ -0,0 +1,18 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
import type { BaseModelType } from './BaseModelType';
|
||||
import type { ControlNetModelFormat } from './ControlNetModelFormat';
|
||||
import type { ModelError } from './ModelError';
|
||||
|
||||
export type ControlNetModelConfig = {
|
||||
name: string;
|
||||
base_model: BaseModelType;
|
||||
type: 'controlnet';
|
||||
path: string;
|
||||
description?: string;
|
||||
model_format: ControlNetModelFormat;
|
||||
error?: ModelError;
|
||||
};
|
||||
|
@ -0,0 +1,8 @@
|
||||
/* istanbul ignore file */
|
||||
/* tslint:disable */
|
||||
/* eslint-disable */
|
||||
|
||||
/**
|
||||
* An enumeration.
|
||||
*/
|
||||
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user