mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into install-script-python-version-error-prompt-fix
This commit is contained in:
commit
df1907e849
@ -2,8 +2,17 @@
|
|||||||
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
import os
|
import os
|
||||||
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
|
SqliteBoardImageRecordStorage,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.board_images import (
|
||||||
|
BoardImagesService,
|
||||||
|
BoardImagesServiceDependencies,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
from invokeai.app.services.images import ImageService
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.metadata import CoreMetadataService
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
@ -57,7 +66,7 @@ class ApiDependencies:
|
|||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = config.db_path
|
db_location = config.db_path
|
||||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
@ -72,14 +81,40 @@ class ApiDependencies:
|
|||||||
DiskLatentsStorage(f"{output_folder}/latents")
|
DiskLatentsStorage(f"{output_folder}/latents")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||||
|
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||||
|
|
||||||
|
boards = BoardService(
|
||||||
|
services=BoardServiceDependencies(
|
||||||
|
board_image_record_storage=board_image_record_storage,
|
||||||
|
board_record_storage=board_record_storage,
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
board_images = BoardImagesService(
|
||||||
|
services=BoardImagesServiceDependencies(
|
||||||
|
board_image_record_storage=board_image_record_storage,
|
||||||
|
board_record_storage=board_record_storage,
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
images = ImageService(
|
images = ImageService(
|
||||||
image_record_storage=image_record_storage,
|
services=ImageServiceDependencies(
|
||||||
image_file_storage=image_file_storage,
|
board_image_record_storage=board_image_record_storage,
|
||||||
metadata=metadata,
|
image_record_storage=image_record_storage,
|
||||||
url=urls,
|
image_file_storage=image_file_storage,
|
||||||
logger=logger,
|
metadata=metadata,
|
||||||
names=names,
|
url=urls,
|
||||||
graph_execution_manager=graph_execution_manager,
|
logger=logger,
|
||||||
|
names=names,
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
@ -87,6 +122,8 @@ class ApiDependencies:
|
|||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
|
boards=boards,
|
||||||
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
|
69
invokeai/app/api/routers/board_images.py
Normal file
69
invokeai/app/api/routers/board_images.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from fastapi import Body, HTTPException, Path, Query
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
from invokeai.app.services.models.image_record import ImageDTO
|
||||||
|
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="create_board_image",
|
||||||
|
responses={
|
||||||
|
201: {"description": "The image was added to a board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
)
|
||||||
|
async def create_board_image(
|
||||||
|
board_id: str = Body(description="The id of the board to add to"),
|
||||||
|
image_name: str = Body(description="The name of the image to add"),
|
||||||
|
):
|
||||||
|
"""Creates a board_image"""
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||||
|
|
||||||
|
@board_images_router.delete(
|
||||||
|
"/",
|
||||||
|
operation_id="remove_board_image",
|
||||||
|
responses={
|
||||||
|
201: {"description": "The image was removed from the board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
)
|
||||||
|
async def remove_board_image(
|
||||||
|
board_id: str = Body(description="The id of the board"),
|
||||||
|
image_name: str = Body(description="The name of the image to remove"),
|
||||||
|
):
|
||||||
|
"""Deletes a board_image"""
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.get(
|
||||||
|
"/{board_id}",
|
||||||
|
operation_id="list_board_images",
|
||||||
|
response_model=OffsetPaginatedResults[ImageDTO],
|
||||||
|
)
|
||||||
|
async def list_board_images(
|
||||||
|
board_id: str = Path(description="The id of the board"),
|
||||||
|
offset: int = Query(default=0, description="The page offset"),
|
||||||
|
limit: int = Query(default=10, description="The number of boards per page"),
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a list of images for a board"""
|
||||||
|
|
||||||
|
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
|
||||||
|
board_id,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
108
invokeai/app/api/routers/boards.py
Normal file
108
invokeai/app/api/routers/boards.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
|
from fastapi import Body, HTTPException, Path, Query
|
||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from invokeai.app.services.board_record_storage import BoardChanges
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||||
|
|
||||||
|
|
||||||
|
@boards_router.post(
|
||||||
|
"/",
|
||||||
|
operation_id="create_board",
|
||||||
|
responses={
|
||||||
|
201: {"description": "The board was created successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=BoardDTO,
|
||||||
|
)
|
||||||
|
async def create_board(
|
||||||
|
board_name: str = Query(description="The name of the board to create"),
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Creates a board"""
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||||
|
|
||||||
|
|
||||||
|
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||||
|
async def get_board(
|
||||||
|
board_id: str = Path(description="The id of board to get"),
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Gets a board"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404, detail="Board not found")
|
||||||
|
|
||||||
|
|
||||||
|
@boards_router.patch(
|
||||||
|
"/{board_id}",
|
||||||
|
operation_id="update_board",
|
||||||
|
responses={
|
||||||
|
201: {
|
||||||
|
"description": "The board was updated successfully",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=BoardDTO,
|
||||||
|
)
|
||||||
|
async def update_board(
|
||||||
|
board_id: str = Path(description="The id of board to update"),
|
||||||
|
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Updates a board"""
|
||||||
|
try:
|
||||||
|
result = ApiDependencies.invoker.services.boards.update(
|
||||||
|
board_id=board_id, changes=changes
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||||
|
|
||||||
|
|
||||||
|
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||||
|
async def delete_board(
|
||||||
|
board_id: str = Path(description="The id of board to delete"),
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a board"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||||
|
except Exception as e:
|
||||||
|
# TODO: Does this need any exception handling at all?
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@boards_router.get(
|
||||||
|
"/",
|
||||||
|
operation_id="list_boards",
|
||||||
|
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||||
|
)
|
||||||
|
async def list_boards(
|
||||||
|
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||||
|
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||||
|
limit: Optional[int] = Query(
|
||||||
|
default=None, description="The number of boards per page"
|
||||||
|
),
|
||||||
|
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||||
|
"""Gets a list of boards"""
|
||||||
|
if all:
|
||||||
|
return ApiDependencies.invoker.services.boards.get_all()
|
||||||
|
elif offset is not None and limit is not None:
|
||||||
|
return ApiDependencies.invoker.services.boards.get_many(
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||||
|
)
|
@ -221,6 +221,9 @@ async def list_images_with_metadata(
|
|||||||
is_intermediate: Optional[bool] = Query(
|
is_intermediate: Optional[bool] = Query(
|
||||||
default=None, description="Whether to list intermediate images"
|
default=None, description="Whether to list intermediate images"
|
||||||
),
|
),
|
||||||
|
board_id: Optional[str] = Query(
|
||||||
|
default=None, description="The board id to filter by"
|
||||||
|
),
|
||||||
offset: int = Query(default=0, description="The page offset"),
|
offset: int = Query(default=0, description="The page offset"),
|
||||||
limit: int = Query(default=10, description="The number of images per page"),
|
limit: int = Query(default=10, description="The number of images per page"),
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
@ -232,6 +235,7 @@ async def list_images_with_metadata(
|
|||||||
image_origin,
|
image_origin,
|
||||||
categories,
|
categories,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
|
board_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dtos
|
return image_dtos
|
||||||
|
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management.models import get_all_model_configs
|
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
models: list[MODEL_CONFIGS]
|
||||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
|
|||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_model: BaseModelType = Query(
|
base_model: Optional[BaseModelType] = Query(
|
||||||
default=None, description="Base model"
|
default=None, description="Base model"
|
||||||
),
|
),
|
||||||
model_type: ModelType = Query(
|
model_type: Optional[ModelType] = Query(
|
||||||
default=None, description="The type of model to get"
|
default=None, description="The type of model to get"
|
||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
|
@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
|||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import sessions, models, images
|
from .api.routers import sessions, models, images, boards, board_images
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
|
|||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(boards.boards_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(board_images.board_images_router, prefix="/api")
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
def custom_openapi():
|
def custom_openapi():
|
||||||
@ -116,6 +120,22 @@ def custom_openapi():
|
|||||||
|
|
||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
for model_config_format_enum in set(get_model_config_enums()):
|
||||||
|
name = model_config_format_enum.__qualname__
|
||||||
|
|
||||||
|
if name in openapi_schema["components"]["schemas"]:
|
||||||
|
# print(f"Config with name {name} already defined")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||||
|
openapi_schema["components"]["schemas"][name] = dict(
|
||||||
|
title=name,
|
||||||
|
description="An enumeration.",
|
||||||
|
type="string",
|
||||||
|
enum=list(v.value for v in model_config_format_enum),
|
||||||
|
)
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
app.openapi_schema = openapi_schema
|
||||||
return app.openapi_schema
|
return app.openapi_schema
|
||||||
|
|
||||||
|
@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
class PipelineModelField(BaseModel):
|
||||||
"""Loading submodels of selected model."""
|
"""Pipeline model field"""
|
||||||
|
|
||||||
type: Literal["sd1_model_loader"] = "sd1_model_loader"
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
|
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads a pipeline model, outputting its submodels."""
|
||||||
|
|
||||||
|
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||||
|
|
||||||
|
model: PipelineModelField = Field(description="The model to load")
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"tags": ["model", "loader"],
|
"tags": ["model", "loader"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
"model": "model"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion1 # TODO:
|
base_model = self.model.base_model
|
||||||
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.Pipeline
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
):
|
):
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not context.services.model_manager.model_exists(
|
if not context.services.model_manager.model_exists(
|
||||||
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
|||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.UNet,
|
submodel=SubModelType.UNet,
|
||||||
),
|
),
|
||||||
scheduler=ModelInfo(
|
scheduler=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Scheduler,
|
submodel=SubModelType.Scheduler,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
),
|
),
|
||||||
clip=ClipField(
|
clip=ClipField(
|
||||||
tokenizer=ModelInfo(
|
tokenizer=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Tokenizer,
|
submodel=SubModelType.Tokenizer,
|
||||||
),
|
),
|
||||||
text_encoder=ModelInfo(
|
text_encoder=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.TextEncoder,
|
submodel=SubModelType.TextEncoder,
|
||||||
),
|
),
|
||||||
loras=[],
|
loras=[],
|
||||||
),
|
),
|
||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
model_name=self.model_name,
|
model_name=model_name,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=ModelType.Pipeline,
|
model_type=model_type,
|
||||||
submodel=SubModelType.Vae,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: optimize(less code copy)
|
|
||||||
class SD2ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {
|
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Vae,
|
submodel=SubModelType.Vae,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
254
invokeai/app/services/board_image_record_storage.py
Normal file
254
invokeai/app/services/board_image_record_storage.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Union, cast
|
||||||
|
from invokeai.app.services.board_record_storage import BoardRecord
|
||||||
|
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.image_record import (
|
||||||
|
ImageRecord,
|
||||||
|
deserialize_image_record,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImageRecordStorageBase(ABC):
|
||||||
|
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_images_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
|
"""Gets images for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Union[str, None]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_count_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Gets the number of images for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
|
_filename: str
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.Lock
|
||||||
|
|
||||||
|
def __init__(self, filename: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._filename = filename
|
||||||
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Enable foreign keys
|
||||||
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the `board_images` junction table."""
|
||||||
|
|
||||||
|
# Create the `board_images` junction table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS board_images (
|
||||||
|
board_id TEXT NOT NULL,
|
||||||
|
image_name TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between boards and images using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (image_name),
|
||||||
|
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for board id
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for board id, sorted by created_at
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON board_images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT INTO board_images (board_id, image_name)
|
||||||
|
VALUES (?, ?)
|
||||||
|
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
|
||||||
|
""",
|
||||||
|
(board_id, image_name, board_id),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM board_images
|
||||||
|
WHERE board_id = ? AND image_name = ?;
|
||||||
|
""",
|
||||||
|
(board_id, image_name),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_images_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
|
# TODO: this isn't paginated yet?
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT images.*
|
||||||
|
FROM board_images
|
||||||
|
INNER JOIN images ON board_images.image_name = images.image_name
|
||||||
|
WHERE board_images.board_id = ?
|
||||||
|
ORDER BY board_images.updated_at DESC;
|
||||||
|
""",
|
||||||
|
(board_id,),
|
||||||
|
)
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
count = cast(int, self._cursor.fetchone()[0])
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return OffsetPaginatedResults(
|
||||||
|
items=images, offset=offset, limit=limit, total=count
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Union[str, None]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT board_id
|
||||||
|
FROM board_images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
result = self._cursor.fetchone()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
return cast(str, result[0])
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_image_count_for_board(self, board_id: str) -> int:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
|
||||||
|
""",
|
||||||
|
(board_id,),
|
||||||
|
)
|
||||||
|
count = cast(int, self._cursor.fetchone()[0])
|
||||||
|
return count
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
142
invokeai/app/services/board_images.py
Normal file
142
invokeai/app/services/board_images.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
from typing import List, Union
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
|
from invokeai.app.services.board_record_storage import (
|
||||||
|
BoardRecord,
|
||||||
|
BoardRecordStorageBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
from invokeai.app.services.image_record_storage import (
|
||||||
|
ImageRecordStorageBase,
|
||||||
|
OffsetPaginatedResults,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesServiceABC(ABC):
|
||||||
|
"""High-level service for board-image relationship management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_images_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
"""Gets images for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Union[str, None]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesServiceDependencies:
|
||||||
|
"""Service dependencies for the BoardImagesService."""
|
||||||
|
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
|
board_records: BoardRecordStorageBase
|
||||||
|
image_records: ImageRecordStorageBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
board_record_storage: BoardRecordStorageBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
):
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
|
self.image_records = image_record_storage
|
||||||
|
self.board_records = board_record_storage
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesService(BoardImagesServiceABC):
|
||||||
|
_services: BoardImagesServiceDependencies
|
||||||
|
|
||||||
|
def __init__(self, services: BoardImagesServiceDependencies):
|
||||||
|
self._services = services
|
||||||
|
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||||
|
|
||||||
|
def get_images_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
image_records = self._services.board_image_records.get_images_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
|
image_dtos = list(
|
||||||
|
map(
|
||||||
|
lambda r: image_record_to_dto(
|
||||||
|
r,
|
||||||
|
self._services.urls.get_image_url(r.image_name),
|
||||||
|
self._services.urls.get_image_url(r.image_name, True),
|
||||||
|
board_id,
|
||||||
|
),
|
||||||
|
image_records.items,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return OffsetPaginatedResults[ImageDTO](
|
||||||
|
items=image_dtos,
|
||||||
|
offset=image_records.offset,
|
||||||
|
limit=image_records.limit,
|
||||||
|
total=image_records.total,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Union[str, None]:
|
||||||
|
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||||
|
return board_id
|
||||||
|
|
||||||
|
|
||||||
|
def board_record_to_dto(
|
||||||
|
board_record: BoardRecord, cover_image_name: str | None, image_count: int
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Converts a board record to a board DTO."""
|
||||||
|
return BoardDTO(
|
||||||
|
**board_record.dict(exclude={'cover_image_name'}),
|
||||||
|
cover_image_name=cover_image_name,
|
||||||
|
image_count=image_count,
|
||||||
|
)
|
329
invokeai/app/services/board_record_storage.py
Normal file
329
invokeai/app/services/board_record_storage.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, cast
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Optional, Union
|
||||||
|
import uuid
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.board_record import (
|
||||||
|
BoardRecord,
|
||||||
|
deserialize_board_record,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, Extra
|
||||||
|
|
||||||
|
|
||||||
|
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||||
|
board_name: Optional[str] = Field(description="The board's new name.")
|
||||||
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the board's new cover image."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordNotFoundException(Exception):
|
||||||
|
"""Raised when an board record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordSaveException(Exception):
|
||||||
|
"""Raised when an board record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordDeleteException(Exception):
|
||||||
|
"""Raised when an board record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the board record store."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
"""Deletes a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Saves a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Gets a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Updates a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardRecord]:
|
||||||
|
"""Gets many board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
"""Gets all board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
|
_filename: str
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.Lock
|
||||||
|
|
||||||
|
def __init__(self, filename: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._filename = filename
|
||||||
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Enable foreign keys
|
||||||
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the `boards` table and `board_images` junction table."""
|
||||||
|
|
||||||
|
# Create the `boards` table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS boards (
|
||||||
|
board_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
board_name TEXT NOT NULL,
|
||||||
|
cover_image_name TEXT,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON boards FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE boards SET updated_at = current_timestamp
|
||||||
|
WHERE board_id = old.board_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM boards
|
||||||
|
WHERE board_id = ?;
|
||||||
|
""",
|
||||||
|
(board_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BoardRecordDeleteException from e
|
||||||
|
except Exception as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BoardRecordDeleteException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
try:
|
||||||
|
board_id = str(uuid.uuid4())
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||||
|
VALUES (?, ?);
|
||||||
|
""",
|
||||||
|
(board_id, board_name),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BoardRecordSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(board_id)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM boards
|
||||||
|
WHERE board_id = ?;
|
||||||
|
""",
|
||||||
|
(board_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BoardRecordNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BoardRecordNotFoundException
|
||||||
|
return BoardRecord(**dict(result))
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardRecord:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Change the name of a board
|
||||||
|
if changes.board_name is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE boards
|
||||||
|
SET board_name = ?
|
||||||
|
WHERE board_id = ?;
|
||||||
|
""",
|
||||||
|
(changes.board_name, board_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change the cover image of a board
|
||||||
|
if changes.cover_image_name is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE boards
|
||||||
|
SET cover_image_name = ?
|
||||||
|
WHERE board_id = ?;
|
||||||
|
""",
|
||||||
|
(changes.cover_image_name, board_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BoardRecordSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(board_id)
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardRecord]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Get all the boards
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM boards
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ? OFFSET ?;
|
||||||
|
""",
|
||||||
|
(limit, offset),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||||
|
|
||||||
|
# Get the total number of boards
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM boards
|
||||||
|
WHERE 1=1;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
count = cast(int, self._cursor.fetchone()[0])
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[BoardRecord](
|
||||||
|
items=boards, offset=offset, limit=limit, total=count
|
||||||
|
)
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Get all the boards
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM boards
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||||
|
|
||||||
|
return boards
|
||||||
|
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
185
invokeai/app/services/boards.py
Normal file
185
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from logging import Logger
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
|
from invokeai.app.services.board_images import board_record_to_dto
|
||||||
|
|
||||||
|
from invokeai.app.services.board_record_storage import (
|
||||||
|
BoardChanges,
|
||||||
|
BoardRecordStorageBase,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.image_record_storage import (
|
||||||
|
ImageRecordStorageBase,
|
||||||
|
OffsetPaginatedResults,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class BoardServiceABC(ABC):
|
||||||
|
"""High-level service for board management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Creates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Gets a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Updates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
"""Gets many boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardDTO]:
|
||||||
|
"""Gets all boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoardServiceDependencies:
|
||||||
|
"""Service dependencies for the BoardService."""
|
||||||
|
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
|
board_records: BoardRecordStorageBase
|
||||||
|
image_records: ImageRecordStorageBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
board_record_storage: BoardRecordStorageBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
):
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
|
self.image_records = image_record_storage
|
||||||
|
self.board_records = board_record_storage
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
|
class BoardService(BoardServiceABC):
|
||||||
|
_services: BoardServiceDependencies
|
||||||
|
|
||||||
|
def __init__(self, services: BoardServiceDependencies):
|
||||||
|
self._services = services
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.save(board_name)
|
||||||
|
return board_record_to_dto(board_record, None, 0)
|
||||||
|
|
||||||
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.get(board_id)
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
board_record.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.update(board_id, changes)
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
board_record.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
self._services.board_records.delete(board_id)
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self, offset: int = 0, limit: int = 10
|
||||||
|
) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
board_records = self._services.board_records.get_many(offset, limit)
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records.items:
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[BoardDTO](
|
||||||
|
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
board_records = self._services.board_records.get_all()
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records:
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return board_dtos
|
@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
|
|||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
) -> OffsetPaginatedResults[ImageRecord]:
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
"""Gets a page of image records."""
|
"""Gets a page of image records."""
|
||||||
pass
|
pass
|
||||||
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
|
|||||||
"""Saves an image record."""
|
"""Saves an image record."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
||||||
|
"""Gets the most recent image for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_filename: str
|
_filename: str
|
||||||
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _create_tables(self) -> None:
|
def _create_tables(self) -> None:
|
||||||
"""Creates the tables for the `images` database."""
|
"""Creates the `images` table."""
|
||||||
|
|
||||||
# Create the `images` table.
|
# Create the `images` table.
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
node_id TEXT,
|
node_id TEXT,
|
||||||
metadata TEXT,
|
metadata TEXT,
|
||||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||||
|
board_id TEXT,
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
-- Updated via trigger
|
-- Updated via trigger
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
AFTER UPDATE
|
AFTER UPDATE
|
||||||
ON images FOR EACH ROW
|
ON images FOR EACH ROW
|
||||||
BEGIN
|
BEGIN
|
||||||
UPDATE images SET updated_at = current_timestamp
|
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
WHERE image_name = old.image_name;
|
WHERE image_name = old.image_name;
|
||||||
END;
|
END;
|
||||||
"""
|
"""
|
||||||
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
""",
|
""",
|
||||||
(changes.is_intermediate, image_name),
|
(changes.is_intermediate, image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
) -> OffsetPaginatedResults[ImageRecord]:
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
|
|
||||||
# Manually build two queries - one for the count, one for the records
|
# Manually build two queries - one for the count, one for the records
|
||||||
|
count_query = """--sql
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM images
|
||||||
|
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||||
|
WHERE 1=1
|
||||||
|
"""
|
||||||
|
|
||||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
images_query = """--sql
|
||||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
SELECT images.*
|
||||||
|
FROM images
|
||||||
|
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||||
|
WHERE 1=1
|
||||||
|
"""
|
||||||
|
|
||||||
query_conditions = ""
|
query_conditions = ""
|
||||||
query_params = []
|
query_params = []
|
||||||
|
|
||||||
if image_origin is not None:
|
if image_origin is not None:
|
||||||
query_conditions += f"""AND image_origin = ?\n"""
|
query_conditions += """--sql
|
||||||
|
AND images.image_origin = ?
|
||||||
|
"""
|
||||||
query_params.append(image_origin.value)
|
query_params.append(image_origin.value)
|
||||||
|
|
||||||
if categories is not None:
|
if categories is not None:
|
||||||
## Convert the enum values to unique list of strings
|
# Convert the enum values to unique list of strings
|
||||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||||
# Create the correct length of placeholders
|
# Create the correct length of placeholders
|
||||||
placeholders = ",".join("?" * len(category_strings))
|
placeholders = ",".join("?" * len(category_strings))
|
||||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
|
||||||
|
query_conditions += f"""--sql
|
||||||
|
AND images.image_category IN ( {placeholders} )
|
||||||
|
"""
|
||||||
|
|
||||||
# Unpack the included categories into the query params
|
# Unpack the included categories into the query params
|
||||||
for c in category_strings:
|
for c in category_strings:
|
||||||
query_params.append(c)
|
query_params.append(c)
|
||||||
|
|
||||||
if is_intermediate is not None:
|
if is_intermediate is not None:
|
||||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
query_conditions += """--sql
|
||||||
|
AND images.is_intermediate = ?
|
||||||
|
"""
|
||||||
|
|
||||||
query_params.append(is_intermediate)
|
query_params.append(is_intermediate)
|
||||||
|
|
||||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
if board_id is not None:
|
||||||
|
query_conditions += """--sql
|
||||||
|
AND board_images.board_id = ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
query_params.append(board_id)
|
||||||
|
|
||||||
|
query_pagination = """--sql
|
||||||
|
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
|
||||||
# Final images query with pagination
|
# Final images query with pagination
|
||||||
images_query += query_conditions + query_pagination + ";"
|
images_query += query_conditions + query_pagination + ";"
|
||||||
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
count_query += query_conditions + ";"
|
count_query += query_conditions + ";"
|
||||||
count_params = query_params.copy()
|
count_params = query_params.copy()
|
||||||
self._cursor.execute(count_query, count_params)
|
self._cursor.execute(count_query, count_params)
|
||||||
count = self._cursor.fetchone()[0]
|
count = cast(int, self._cursor.fetchone()[0])
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
raise e
|
raise e
|
||||||
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
raise ImageRecordSaveException from e
|
raise ImageRecordSaveException from e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
def get_most_recent_image_for_board(
|
||||||
|
self, board_id: str
|
||||||
|
) -> Union[ImageRecord, None]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT images.*
|
||||||
|
FROM images
|
||||||
|
JOIN board_images ON images.image_name = board_images.image_name
|
||||||
|
WHERE board_images.board_id = ?
|
||||||
|
ORDER BY images.created_at DESC
|
||||||
|
LIMIT 1;
|
||||||
|
""",
|
||||||
|
(board_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return deserialize_image_record(dict(result))
|
||||||
|
@ -10,6 +10,7 @@ from invokeai.app.models.image import (
|
|||||||
InvalidOriginException,
|
InvalidOriginException,
|
||||||
)
|
)
|
||||||
from invokeai.app.models.metadata import ImageMetadata
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
from invokeai.app.services.image_record_storage import (
|
from invokeai.app.services.image_record_storage import (
|
||||||
ImageRecordDeleteException,
|
ImageRecordDeleteException,
|
||||||
ImageRecordNotFoundException,
|
ImageRecordNotFoundException,
|
||||||
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
|
|||||||
image_category: ImageCategory,
|
image_category: ImageCategory,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Creates an image, storing the file and its metadata."""
|
"""Creates an image, storing the file and its metadata."""
|
||||||
pass
|
pass
|
||||||
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_name: str) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
"""Gets an image's path."""
|
"""Gets an image's path."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
|
|||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
"""Gets a paginated list of image DTOs."""
|
"""Gets a paginated list of image DTOs."""
|
||||||
pass
|
pass
|
||||||
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
|
|||||||
class ImageServiceDependencies:
|
class ImageServiceDependencies:
|
||||||
"""Service dependencies for the ImageService."""
|
"""Service dependencies for the ImageService."""
|
||||||
|
|
||||||
records: ImageRecordStorageBase
|
image_records: ImageRecordStorageBase
|
||||||
files: ImageFileStorageBase
|
image_files: ImageFileStorageBase
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
metadata: MetadataServiceBase
|
metadata: MetadataServiceBase
|
||||||
urls: UrlServiceBase
|
urls: UrlServiceBase
|
||||||
logger: Logger
|
logger: Logger
|
||||||
@ -126,14 +129,16 @@ class ImageServiceDependencies:
|
|||||||
self,
|
self,
|
||||||
image_record_storage: ImageRecordStorageBase,
|
image_record_storage: ImageRecordStorageBase,
|
||||||
image_file_storage: ImageFileStorageBase,
|
image_file_storage: ImageFileStorageBase,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
metadata: MetadataServiceBase,
|
metadata: MetadataServiceBase,
|
||||||
url: UrlServiceBase,
|
url: UrlServiceBase,
|
||||||
logger: Logger,
|
logger: Logger,
|
||||||
names: NameServiceBase,
|
names: NameServiceBase,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
):
|
):
|
||||||
self.records = image_record_storage
|
self.image_records = image_record_storage
|
||||||
self.files = image_file_storage
|
self.image_files = image_file_storage
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.urls = url
|
self.urls = url
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
@ -144,25 +149,8 @@ class ImageServiceDependencies:
|
|||||||
class ImageService(ImageServiceABC):
|
class ImageService(ImageServiceABC):
|
||||||
_services: ImageServiceDependencies
|
_services: ImageServiceDependencies
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, services: ImageServiceDependencies):
|
||||||
self,
|
self._services = services
|
||||||
image_record_storage: ImageRecordStorageBase,
|
|
||||||
image_file_storage: ImageFileStorageBase,
|
|
||||||
metadata: MetadataServiceBase,
|
|
||||||
url: UrlServiceBase,
|
|
||||||
logger: Logger,
|
|
||||||
names: NameServiceBase,
|
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
|
||||||
):
|
|
||||||
self._services = ImageServiceDependencies(
|
|
||||||
image_record_storage=image_record_storage,
|
|
||||||
image_file_storage=image_file_storage,
|
|
||||||
metadata=metadata,
|
|
||||||
url=url,
|
|
||||||
logger=logger,
|
|
||||||
names=names,
|
|
||||||
graph_execution_manager=graph_execution_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
def create(
|
def create(
|
||||||
self,
|
self,
|
||||||
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
created_at = self._services.records.save(
|
self._services.image_records.save(
|
||||||
# Non-nullable fields
|
# Non-nullable fields
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_origin=image_origin,
|
image_origin=image_origin,
|
||||||
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._services.files.save(
|
self._services.image_files.save(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image=image,
|
image=image,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_url = self._services.urls.get_image_url(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
|
||||||
|
|
||||||
return ImageDTO(
|
return image_dto
|
||||||
# Non-nullable fields
|
|
||||||
image_name=image_name,
|
|
||||||
image_origin=image_origin,
|
|
||||||
image_category=image_category,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
# Nullable fields
|
|
||||||
node_id=node_id,
|
|
||||||
session_id=session_id,
|
|
||||||
metadata=metadata,
|
|
||||||
# Meta fields
|
|
||||||
created_at=created_at,
|
|
||||||
updated_at=created_at, # this is always the same as the created_at at this time
|
|
||||||
deleted_at=None,
|
|
||||||
is_intermediate=is_intermediate,
|
|
||||||
# Extra non-nullable fields for DTO
|
|
||||||
image_url=image_url,
|
|
||||||
thumbnail_url=thumbnail_url,
|
|
||||||
)
|
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to save image record")
|
self._services.logger.error("Failed to save image record")
|
||||||
raise
|
raise
|
||||||
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
|
|||||||
changes: ImageRecordChanges,
|
changes: ImageRecordChanges,
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
self._services.records.update(image_name, changes)
|
self._services.image_records.update(image_name, changes)
|
||||||
return self.get_dto(image_name)
|
return self.get_dto(image_name)
|
||||||
except ImageRecordSaveException:
|
except ImageRecordSaveException:
|
||||||
self._services.logger.error("Failed to update image record")
|
self._services.logger.error("Failed to update image record")
|
||||||
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get(image_name)
|
return self._services.image_files.get(image_name)
|
||||||
except ImageFileNotFoundException:
|
except ImageFileNotFoundException:
|
||||||
self._services.logger.error("Failed to get image file")
|
self._services.logger.error("Failed to get image file")
|
||||||
raise
|
raise
|
||||||
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
try:
|
try:
|
||||||
return self._services.records.get(image_name)
|
return self._services.image_records.get(image_name)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
raise
|
raise
|
||||||
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.records.get(image_name)
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
image_dto = image_record_to_dto(
|
||||||
image_record,
|
image_record,
|
||||||
self._services.urls.get_image_url(image_name),
|
self._services.urls.get_image_url(image_name),
|
||||||
self._services.urls.get_image_url(image_name, True),
|
self._services.urls.get_image_url(image_name, True),
|
||||||
|
self._services.board_image_records.get_board_for_image(image_name),
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
try:
|
try:
|
||||||
return self._services.files.get_path(image_name, thumbnail)
|
return self._services.image_files.get_path(image_name, thumbnail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem getting image path")
|
self._services.logger.error("Problem getting image path")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
def validate_path(self, path: str) -> bool:
|
||||||
try:
|
try:
|
||||||
return self._services.files.validate_path(path)
|
return self._services.image_files.validate_path(path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._services.logger.error("Problem validating image path")
|
self._services.logger.error("Problem validating image path")
|
||||||
raise e
|
raise e
|
||||||
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
|
|||||||
image_origin: Optional[ResourceOrigin] = None,
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
is_intermediate: Optional[bool] = None,
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
try:
|
try:
|
||||||
results = self._services.records.get_many(
|
results = self._services.image_records.get_many(
|
||||||
offset,
|
offset,
|
||||||
limit,
|
limit,
|
||||||
image_origin,
|
image_origin,
|
||||||
categories,
|
categories,
|
||||||
is_intermediate,
|
is_intermediate,
|
||||||
|
board_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_dtos = list(
|
image_dtos = list(
|
||||||
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
|
|||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(r.image_name),
|
self._services.urls.get_image_url(r.image_name),
|
||||||
self._services.urls.get_image_url(r.image_name, True),
|
self._services.urls.get_image_url(r.image_name, True),
|
||||||
|
self._services.board_image_records.get_board_for_image(
|
||||||
|
r.image_name
|
||||||
|
),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def delete(self, image_name: str):
|
def delete(self, image_name: str):
|
||||||
try:
|
try:
|
||||||
self._services.files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.records.delete(image_name)
|
self._services.image_records.delete(image_name)
|
||||||
except ImageRecordDeleteException:
|
except ImageRecordDeleteException:
|
||||||
self._services.logger.error(f"Failed to delete image record")
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
raise
|
raise
|
||||||
|
@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from invokeai.app.services.images import ImageService
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.backend import ModelManager
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
@ -26,9 +28,9 @@ class InvocationServices:
|
|||||||
model_manager: "ModelManager"
|
model_manager: "ModelManager"
|
||||||
restoration: "RestorationServices"
|
restoration: "RestorationServices"
|
||||||
configuration: "InvokeAISettings"
|
configuration: "InvokeAISettings"
|
||||||
images: "ImageService"
|
images: "ImageServiceABC"
|
||||||
|
boards: "BoardServiceABC"
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
board_images: "BoardImagesServiceABC"
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
@ -39,7 +41,9 @@ class InvocationServices:
|
|||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
latents: "LatentsStorageBase",
|
latents: "LatentsStorageBase",
|
||||||
images: "ImageService",
|
images: "ImageServiceABC",
|
||||||
|
boards: "BoardServiceABC",
|
||||||
|
board_images: "BoardImagesServiceABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
@ -52,9 +56,12 @@ class InvocationServices:
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.images = images
|
self.images = images
|
||||||
|
self.boards = boards
|
||||||
|
self.board_images = board_images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
self.restoration = restoration
|
self.restoration = restoration
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
|
self.boards = boards
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import torch
|
import torch
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
|
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management.model_manager import (
|
||||||
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name and typeof the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
"""Sets the default model to the indicated name."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type,
|
model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name of the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
return self.mgr.default_model()
|
|
||||||
|
|
||||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
|
||||||
"""Sets the default model to the indicated name."""
|
|
||||||
self.mgr.set_default_model(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
"""
|
"""
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None
|
model_type: Optional[ModelType] = None
|
||||||
) -> dict:
|
) -> list[dict]:
|
||||||
|
# ) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a list of models.
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models(base_model, model_type)
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
|
62
invokeai/app/services/models/board_record.py
Normal file
62
invokeai/app/services/models/board_record.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||||
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecord(BaseModel):
|
||||||
|
"""Deserialized board record."""
|
||||||
|
|
||||||
|
board_id: str = Field(description="The unique ID of the board.")
|
||||||
|
"""The unique ID of the board."""
|
||||||
|
board_name: str = Field(description="The name of the board.")
|
||||||
|
"""The name of the board."""
|
||||||
|
created_at: Union[datetime, str] = Field(
|
||||||
|
description="The created timestamp of the board."
|
||||||
|
)
|
||||||
|
"""The created timestamp of the image."""
|
||||||
|
updated_at: Union[datetime, str] = Field(
|
||||||
|
description="The updated timestamp of the board."
|
||||||
|
)
|
||||||
|
"""The updated timestamp of the image."""
|
||||||
|
deleted_at: Union[datetime, str, None] = Field(
|
||||||
|
description="The deleted timestamp of the board."
|
||||||
|
)
|
||||||
|
"""The updated timestamp of the image."""
|
||||||
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the cover image of the board."
|
||||||
|
)
|
||||||
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
|
class BoardDTO(BoardRecord):
|
||||||
|
"""Deserialized board record with cover image URL and image count."""
|
||||||
|
|
||||||
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the board's cover image."
|
||||||
|
)
|
||||||
|
"""The URL of the thumbnail of the most recent image in the board."""
|
||||||
|
image_count: int = Field(description="The number of images in the board.")
|
||||||
|
"""The number of images in the board."""
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||||
|
"""Deserializes a board record."""
|
||||||
|
|
||||||
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
board_id = board_dict.get("board_id", "unknown")
|
||||||
|
board_name = board_dict.get("board_name", "unknown")
|
||||||
|
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||||
|
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||||
|
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||||
|
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||||
|
|
||||||
|
return BoardRecord(
|
||||||
|
board_id=board_id,
|
||||||
|
board_name=board_name,
|
||||||
|
cover_image_name=cover_image_name,
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=updated_at,
|
||||||
|
deleted_at=deleted_at,
|
||||||
|
)
|
@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
|
board_id: Union[str, None] = Field(
|
||||||
|
description="The id of the board the image belongs to, if one exists."
|
||||||
|
)
|
||||||
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
def image_record_to_dto(
|
||||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
**image_record.dict(),
|
**image_record.dict(),
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
thumbnail_url=thumbnail_url,
|
thumbnail_url=thumbnail_url,
|
||||||
|
board_id=board_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -266,6 +266,8 @@ class ModelManager(object):
|
|||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
|
# alias for config file
|
||||||
|
model_config["model_format"] = model_config.pop("format")
|
||||||
self.models[model_key] = model_class.create_config(**model_config)
|
self.models[model_key] = model_class.create_config(**model_config)
|
||||||
|
|
||||||
# check config version number and update on disk/RAM if necessary
|
# check config version number and update on disk/RAM if necessary
|
||||||
@ -445,38 +447,6 @@ class ModelManager(object):
|
|||||||
_cache = self.cache,
|
_cache = self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns the name of the default model, or None
|
|
||||||
if none is defined.
|
|
||||||
"""
|
|
||||||
for model_key, model_config in self.models.items():
|
|
||||||
if model_config.default:
|
|
||||||
return self.parse_key(model_key)
|
|
||||||
|
|
||||||
for model_key, _ in self.models.items():
|
|
||||||
return self.parse_key(model_key)
|
|
||||||
else:
|
|
||||||
return None # TODO: or redo as (None, None, None)
|
|
||||||
|
|
||||||
def set_default_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Set the default model. The change will not take
|
|
||||||
effect until you call model_manager.commit()
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_key = self.model_key(model_name, base_model, model_type)
|
|
||||||
if model_key not in self.models:
|
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
|
||||||
|
|
||||||
for cur_model_key, config in self.models.items():
|
|
||||||
config.default = cur_model_key == model_key
|
|
||||||
|
|
||||||
def model_info(
|
def model_info(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
@ -503,9 +473,9 @@ class ModelManager(object):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Dict[str, Dict[str, str]]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a dict of models, in format [base_model][model_type][model_name]
|
Return a list of models.
|
||||||
|
|
||||||
Please use model_manager.models() to get all the model names,
|
Please use model_manager.models() to get all the model names,
|
||||||
model_manager.model_info('model-name') to get the stanza for the model
|
model_manager.model_info('model-name') to get the stanza for the model
|
||||||
@ -513,7 +483,7 @@ class ModelManager(object):
|
|||||||
object derived from models.yaml
|
object derived from models.yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = dict()
|
models = []
|
||||||
for model_key in sorted(self.models, key=str.casefold):
|
for model_key in sorted(self.models, key=str.casefold):
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
@ -523,18 +493,16 @@ class ModelManager(object):
|
|||||||
if model_type is not None and cur_model_type != model_type:
|
if model_type is not None and cur_model_type != model_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cur_base_model not in models:
|
model_dict = dict(
|
||||||
models[cur_base_model] = dict()
|
|
||||||
if cur_model_type not in models[cur_base_model]:
|
|
||||||
models[cur_base_model][cur_model_type] = dict()
|
|
||||||
|
|
||||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
|
||||||
**model_config.dict(exclude_defaults=True),
|
**model_config.dict(exclude_defaults=True),
|
||||||
|
# OpenAPIModelInfoBase
|
||||||
name=cur_model_name,
|
name=cur_model_name,
|
||||||
base_model=cur_base_model,
|
base_model=cur_base_model,
|
||||||
type=cur_model_type,
|
type=cur_model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
@ -646,7 +614,9 @@ class ModelManager(object):
|
|||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
# TODO: or exclude_unset better fits here?
|
# TODO: or exclude_unset better fits here?
|
||||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||||
|
# alias for config file
|
||||||
|
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
|
import inspect
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Literal, get_origin
|
||||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||||
from .vae import VaeModel
|
from .vae import VaeModel
|
||||||
@ -29,10 +33,63 @@ MODEL_CLASSES = {
|
|||||||
#},
|
#},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_all_model_configs():
|
MODEL_CONFIGS = list()
|
||||||
configs = set()
|
OPENAPI_MODEL_CONFIGS = list()
|
||||||
for models in MODEL_CLASSES.values():
|
|
||||||
for _, model in models.items():
|
class OpenAPIModelInfoBase(BaseModel):
|
||||||
configs.update(model._get_configs().values())
|
name: str
|
||||||
configs.discard(None)
|
base_model: BaseModelType
|
||||||
return list(configs) # TODO: set, list or tuple
|
type: ModelType
|
||||||
|
|
||||||
|
|
||||||
|
for base_model, models in MODEL_CLASSES.items():
|
||||||
|
for model_type, model_class in models.items():
|
||||||
|
model_configs = set(model_class._get_configs().values())
|
||||||
|
model_configs.discard(None)
|
||||||
|
MODEL_CONFIGS.extend(model_configs)
|
||||||
|
|
||||||
|
for cfg in model_configs:
|
||||||
|
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||||
|
openapi_cfg_name = model_name + cfg_name
|
||||||
|
if openapi_cfg_name in vars():
|
||||||
|
continue
|
||||||
|
|
||||||
|
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||||
|
__annotations__ = dict(
|
||||||
|
type=Literal[model_type.value],
|
||||||
|
),
|
||||||
|
))
|
||||||
|
|
||||||
|
#globals()[openapi_cfg_name] = api_wrapper
|
||||||
|
vars()[openapi_cfg_name] = api_wrapper
|
||||||
|
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||||
|
|
||||||
|
def get_model_config_enums():
|
||||||
|
enums = list()
|
||||||
|
|
||||||
|
for model_config in MODEL_CONFIGS:
|
||||||
|
fields = inspect.get_annotations(model_config)
|
||||||
|
try:
|
||||||
|
field = fields["model_format"]
|
||||||
|
except:
|
||||||
|
raise Exception("format field not found")
|
||||||
|
|
||||||
|
# model_format: None
|
||||||
|
# model_format: SomeModelFormat
|
||||||
|
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||||
|
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||||
|
|
||||||
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
|
enums.append(field)
|
||||||
|
|
||||||
|
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||||
|
enums.append(type(field.__args__[0]))
|
||||||
|
|
||||||
|
elif field is None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||||
|
|
||||||
|
return enums
|
||||||
|
|
||||||
|
@ -48,12 +48,10 @@ class ModelError(str, Enum):
|
|||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
path: str # or Path
|
path: str # or Path
|
||||||
#name: str # not included as present in model key
|
|
||||||
description: Optional[str] = Field(None)
|
description: Optional[str] = Field(None)
|
||||||
format: Optional[str] = Field(None)
|
model_format: Optional[str] = Field(None)
|
||||||
default: Optional[bool] = Field(False)
|
|
||||||
# do not save to config
|
# do not save to config
|
||||||
error: Optional[ModelError] = Field(None, exclude=True)
|
error: Optional[ModelError] = Field(None)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
use_enum_values = True
|
use_enum_values = True
|
||||||
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||||
if len(subtypes) < 2:
|
if len(subtypes) < 2:
|
||||||
raise Exception("Invalid subfolder definition!")
|
raise Exception("Invalid subfolder definition!")
|
||||||
|
if all(t is None for t in subtypes):
|
||||||
|
return None
|
||||||
|
elif any(t is None for t in subtypes):
|
||||||
|
raise Exception(f"Unsupported definition: {subtypes}")
|
||||||
|
|
||||||
if subtypes[0] in ["diffusers", "transformers"]:
|
if subtypes[0] in ["diffusers", "transformers"]:
|
||||||
res_type = sys.modules[subtypes[0]]
|
res_type = sys.modules[subtypes[0]]
|
||||||
subtypes = subtypes[1:]
|
subtypes = subtypes[1:]
|
||||||
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
fields = inspect.get_annotations(value)
|
fields = inspect.get_annotations(value)
|
||||||
if "format" not in fields:
|
try:
|
||||||
raise Exception("Invalid config definition - format field not found")
|
field = fields["model_format"]
|
||||||
|
except:
|
||||||
|
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||||
|
|
||||||
format_type = typing.get_origin(fields["format"])
|
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||||
if format_type not in {None, Literal, Union}:
|
for model_format in field:
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
configs[model_format.value] = value
|
||||||
|
|
||||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
for model_format in field.__args__:
|
||||||
|
configs[model_format.value] = value
|
||||||
|
|
||||||
|
elif field is None:
|
||||||
|
configs[None] = value
|
||||||
|
|
||||||
if format_type == Union:
|
|
||||||
f_fields = fields["format"].__args__
|
|
||||||
else:
|
else:
|
||||||
f_fields = (fields["format"],)
|
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||||
|
|
||||||
|
|
||||||
for field in f_fields:
|
|
||||||
if field is None:
|
|
||||||
format_name = None
|
|
||||||
else:
|
|
||||||
format_name = field.__args__[0]
|
|
||||||
|
|
||||||
configs[format_name] = value # TODO: error when override(multiple)?
|
|
||||||
|
|
||||||
|
|
||||||
cls.__configs = configs
|
cls.__configs = configs
|
||||||
return cls.__configs
|
return cls.__configs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||||
if "format" not in kwargs:
|
if "model_format" not in kwargs:
|
||||||
raise Exception("Field 'format' not found in model config")
|
raise Exception("Field 'model_format' not found in model config")
|
||||||
|
|
||||||
configs = cls._get_configs()
|
configs = cls._get_configs()
|
||||||
return configs[kwargs["format"]](**kwargs)
|
return configs[kwargs["model_format"]](**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=cls.detect_format(path),
|
model_format=cls.detect_format(path),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -14,12 +15,16 @@ from .base import (
|
|||||||
classproperty,
|
classproperty,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class ControlNetModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class ControlNetModel(ModelBase):
|
class ControlNetModel(ModelBase):
|
||||||
#model_class: Type
|
#model_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: ControlNetModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.ControlNet
|
assert model_type == ModelType.ControlNet
|
||||||
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return ControlNetModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return ControlNetModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
|
||||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
from enum import Enum
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
@ -12,11 +13,15 @@ from .base import (
|
|||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
|
|
||||||
|
class LoRAModelFormat(str, Enum):
|
||||||
|
LyCORIS = "lycoris"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class LoRAModel(ModelBase):
|
class LoRAModel(ModelBase):
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
model_format: LoRAModelFormat # TODO:
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Lora
|
assert model_type == ModelType.Lora
|
||||||
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return LoRAModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "lycoris"
|
return LoRAModelFormat.LyCORIS
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) == "diffusers":
|
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||||
# TODO: add diffusers lora when it stabilizes a bit
|
# TODO: add diffusers lora when it stabilizes a bit
|
||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
else:
|
else:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
from enum import Enum
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
@ -19,16 +20,19 @@ from .base import (
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class StableDiffusion1ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion1Model(DiffusersModel):
|
class StableDiffusion1Model(DiffusersModel):
|
||||||
|
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion1ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion1ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
|
|||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
class StableDiffusion2ModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class StableDiffusion2Model(DiffusersModel):
|
class StableDiffusion2Model(DiffusersModel):
|
||||||
|
|
||||||
# TODO: check that configs overwriten properly
|
# TODO: check that configs overwriten properly
|
||||||
class DiffusersConfig(ModelConfigBase):
|
class DiffusersConfig(ModelConfigBase):
|
||||||
format: Literal["diffusers"]
|
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
prediction_type: SchedulerPredictionType
|
prediction_type: SchedulerPredictionType
|
||||||
upcast_attention: bool
|
upcast_attention: bool
|
||||||
|
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
format: Literal["checkpoint"]
|
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||||
vae: Optional[str] = Field(None)
|
vae: Optional[str] = Field(None)
|
||||||
config: Optional[str] = Field(None)
|
config: Optional[str] = Field(None)
|
||||||
variant: ModelVariantType
|
variant: ModelVariantType
|
||||||
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
def probe_config(cls, path: str, **kwargs):
|
def probe_config(cls, path: str, **kwargs):
|
||||||
model_format = cls.detect_format(path)
|
model_format = cls.detect_format(path)
|
||||||
ckpt_config_path = kwargs.get("config", None)
|
ckpt_config_path = kwargs.get("config", None)
|
||||||
if model_format == "checkpoint":
|
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||||
if ckpt_config_path:
|
if ckpt_config_path:
|
||||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||||
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||||
|
|
||||||
elif model_format == "diffusers":
|
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||||
if os.path.exists(unet_config_path):
|
if os.path.exists(unet_config_path):
|
||||||
with open(unet_config_path, "r") as f:
|
with open(unet_config_path, "r") as f:
|
||||||
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
|
|
||||||
return cls.create_config(
|
return cls.create_config(
|
||||||
path=path,
|
path=path,
|
||||||
format=model_format,
|
model_format=model_format,
|
||||||
|
|
||||||
config=ckpt_config_path,
|
config=ckpt_config_path,
|
||||||
variant=variant,
|
variant=variant,
|
||||||
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, model_path: str):
|
def detect_format(cls, model_path: str):
|
||||||
if os.path.isdir(model_path):
|
if os.path.isdir(model_path):
|
||||||
return "diffusers"
|
return StableDiffusion2ModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return StableDiffusion2ModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
|
|||||||
prediction_type = SchedulerPredictionType.Epsilon
|
prediction_type = SchedulerPredictionType.Epsilon
|
||||||
|
|
||||||
elif version == BaseModelType.StableDiffusion2:
|
elif version == BaseModelType.StableDiffusion2:
|
||||||
upcast_attention = config.upcast_attention
|
upcast_attention = model_config.upcast_attention
|
||||||
prediction_type = config.prediction_type
|
prediction_type = model_config.prediction_type
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Unknown model provided: {version}")
|
raise Exception(f"Unknown model provided: {version}")
|
||||||
|
@ -16,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
|||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: None
|
model_format: None
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.TextualInversion
|
assert model_type == ModelType.TextualInversion
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
import safetensors
|
||||||
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Union, Literal
|
||||||
from .base import (
|
from .base import (
|
||||||
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from diffusers.utils import is_safetensors_available
|
from diffusers.utils import is_safetensors_available
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
class VaeModelFormat(str, Enum):
|
||||||
|
Checkpoint = "checkpoint"
|
||||||
|
Diffusers = "diffusers"
|
||||||
|
|
||||||
class VaeModel(ModelBase):
|
class VaeModel(ModelBase):
|
||||||
#vae_class: Type
|
#vae_class: Type
|
||||||
#model_size: int
|
#model_size: int
|
||||||
|
|
||||||
class Config(ModelConfigBase):
|
class Config(ModelConfigBase):
|
||||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
model_format: VaeModelFormat
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.Vae
|
assert model_type == ModelType.Vae
|
||||||
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
return "diffusers"
|
return VaeModelFormat.Diffusers
|
||||||
else:
|
else:
|
||||||
return "checkpoint"
|
return VaeModelFormat.Checkpoint
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
|
|||||||
config: ModelConfigBase, # empty config or config of parent model
|
config: ModelConfigBase, # empty config or config of parent model
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
if cls.detect_format(model_path) != "diffusers":
|
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||||
return _convert_vae_ckpt_and_cache(
|
return _convert_vae_ckpt_and_cache(
|
||||||
weights_path=model_path,
|
weights_path=model_path,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
|
@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
|
|||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const DEFAULT_CONFIG = {};
|
const DEFAULT_CONFIG = {};
|
||||||
|
|
||||||
@ -45,6 +47,18 @@ const App = ({
|
|||||||
|
|
||||||
const isApplicationReady = useIsApplicationReady();
|
const isApplicationReady = useIsApplicationReady();
|
||||||
|
|
||||||
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
const { data: controlnetModels } = useListModelsQuery({
|
||||||
|
model_type: 'controlnet',
|
||||||
|
});
|
||||||
|
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||||
|
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||||
|
const { data: embeddingModels } = useListModelsQuery({
|
||||||
|
model_type: 'embedding',
|
||||||
|
});
|
||||||
|
|
||||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
@ -143,6 +157,7 @@ const App = ({
|
|||||||
</Portal>
|
</Portal>
|
||||||
</Grid>
|
</Grid>
|
||||||
<DeleteImageModal />
|
<DeleteImageModal />
|
||||||
|
<UpdateImageBoardModal />
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<GlobalHotkeys />
|
<GlobalHotkeys />
|
||||||
</>
|
</>
|
||||||
|
@ -21,6 +21,8 @@ import {
|
|||||||
DeleteImageContext,
|
DeleteImageContext,
|
||||||
DeleteImageContextProvider,
|
DeleteImageContextProvider,
|
||||||
} from 'app/contexts/DeleteImageContext';
|
} from 'app/contexts/DeleteImageContext';
|
||||||
|
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||||
|
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||||
@ -76,11 +78,13 @@ const InvokeAIUI = ({
|
|||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<ImageDndContext>
|
<ImageDndContext>
|
||||||
<DeleteImageContextProvider>
|
<DeleteImageContextProvider>
|
||||||
<App
|
<AddImageToBoardContextProvider>
|
||||||
config={config}
|
<App
|
||||||
headerComponent={headerComponent}
|
config={config}
|
||||||
setIsReady={setIsReady}
|
headerComponent={headerComponent}
|
||||||
/>
|
setIsReady={setIsReady}
|
||||||
|
/>
|
||||||
|
</AddImageToBoardContextProvider>
|
||||||
</DeleteImageContextProvider>
|
</DeleteImageContextProvider>
|
||||||
</ImageDndContext>
|
</ImageDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
|
@ -0,0 +1,89 @@
|
|||||||
|
import { useDisclosure } from '@chakra-ui/react';
|
||||||
|
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { useAddImageToBoardMutation } from 'services/apiSlice';
|
||||||
|
|
||||||
|
export type ImageUsage = {
|
||||||
|
isInitialImage: boolean;
|
||||||
|
isCanvasImage: boolean;
|
||||||
|
isNodesImage: boolean;
|
||||||
|
isControlNetImage: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
type AddImageToBoardContextValue = {
|
||||||
|
/**
|
||||||
|
* Whether the move image dialog is open.
|
||||||
|
*/
|
||||||
|
isOpen: boolean;
|
||||||
|
/**
|
||||||
|
* Closes the move image dialog.
|
||||||
|
*/
|
||||||
|
onClose: () => void;
|
||||||
|
/**
|
||||||
|
* The image pending movement
|
||||||
|
*/
|
||||||
|
image?: ImageDTO;
|
||||||
|
onClickAddToBoard: (image: ImageDTO) => void;
|
||||||
|
handleAddToBoard: (boardId: string) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const AddImageToBoardContext =
|
||||||
|
createContext<AddImageToBoardContextValue>({
|
||||||
|
isOpen: false,
|
||||||
|
onClose: () => undefined,
|
||||||
|
onClickAddToBoard: () => undefined,
|
||||||
|
handleAddToBoard: () => undefined,
|
||||||
|
});
|
||||||
|
|
||||||
|
type Props = PropsWithChildren;
|
||||||
|
|
||||||
|
export const AddImageToBoardContextProvider = (props: Props) => {
|
||||||
|
const [imageToMove, setImageToMove] = useState<ImageDTO>();
|
||||||
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
|
const [addImageToBoard, result] = useAddImageToBoardMutation();
|
||||||
|
|
||||||
|
// Clean up after deleting or dismissing the modal
|
||||||
|
const closeAndClearImageToDelete = useCallback(() => {
|
||||||
|
setImageToMove(undefined);
|
||||||
|
onClose();
|
||||||
|
}, [onClose]);
|
||||||
|
|
||||||
|
const onClickAddToBoard = useCallback(
|
||||||
|
(image?: ImageDTO) => {
|
||||||
|
if (!image) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setImageToMove(image);
|
||||||
|
onOpen();
|
||||||
|
},
|
||||||
|
[setImageToMove, onOpen]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleAddToBoard = useCallback(
|
||||||
|
(boardId: string) => {
|
||||||
|
if (imageToMove) {
|
||||||
|
addImageToBoard({
|
||||||
|
board_id: boardId,
|
||||||
|
image_name: imageToMove.image_name,
|
||||||
|
});
|
||||||
|
closeAndClearImageToDelete();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[addImageToBoard, closeAndClearImageToDelete, imageToMove]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AddImageToBoardContext.Provider
|
||||||
|
value={{
|
||||||
|
isOpen,
|
||||||
|
image: imageToMove,
|
||||||
|
onClose: closeAndClearImageToDelete,
|
||||||
|
onClickAddToBoard,
|
||||||
|
handleAddToBoard,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{props.children}
|
||||||
|
</AddImageToBoardContext.Provider>
|
||||||
|
);
|
||||||
|
};
|
@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
|
|||||||
(state: RootState, image_name?: string) => image_name,
|
(state: RootState, image_name?: string) => image_name,
|
||||||
],
|
],
|
||||||
(generation, canvas, nodes, controlNet, image_name) => {
|
(generation, canvas, nodes, controlNet, image_name) => {
|
||||||
const isInitialImage = generation.initialImage?.image_name === image_name;
|
const isInitialImage = generation.initialImage?.imageName === image_name;
|
||||||
|
|
||||||
const isCanvasImage = canvas.layerState.objects.some(
|
const isCanvasImage = canvas.layerState.objects.some(
|
||||||
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
|
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const isNodesImage = nodes.nodes.some((node) => {
|
const isNodesImage = nodes.nodes.some((node) => {
|
||||||
return some(
|
return some(
|
||||||
node.data.inputs,
|
node.data.inputs,
|
||||||
(input) =>
|
(input) => input.type === 'image' && input.value === image_name
|
||||||
input.type === 'image' && input.value?.image_name === image_name
|
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
const isControlNetImage = some(
|
const isControlNetImage = some(
|
||||||
controlNet.controlNets,
|
controlNet.controlNets,
|
||||||
(c) =>
|
(c) =>
|
||||||
c.controlImage?.image_name === image_name ||
|
c.controlImage === image_name || c.processedControlImage === image_name
|
||||||
c.processedControlImage?.image_name === image_name
|
|
||||||
);
|
);
|
||||||
|
|
||||||
const imageUsage: ImageUsage = {
|
const imageUsage: ImageUsage = {
|
||||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
|||||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
|
||||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||||
import { omit } from 'lodash-es';
|
import { omit } from 'lodash-es';
|
||||||
@ -18,7 +17,6 @@ const serializationDenylist: {
|
|||||||
gallery: galleryPersistDenylist,
|
gallery: galleryPersistDenylist,
|
||||||
generation: generationPersistDenylist,
|
generation: generationPersistDenylist,
|
||||||
lightbox: lightboxPersistDenylist,
|
lightbox: lightboxPersistDenylist,
|
||||||
models: modelsPersistDenylist,
|
|
||||||
nodes: nodesPersistDenylist,
|
nodes: nodesPersistDenylist,
|
||||||
postprocessing: postprocessingPersistDenylist,
|
postprocessing: postprocessingPersistDenylist,
|
||||||
system: systemPersistDenylist,
|
system: systemPersistDenylist,
|
||||||
|
@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
|||||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||||
import { initialConfigState } from 'features/system/store/configSlice';
|
import { initialConfigState } from 'features/system/store/configSlice';
|
||||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
|
||||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||||
@ -21,7 +20,6 @@ const initialStates: {
|
|||||||
gallery: initialGalleryState,
|
gallery: initialGalleryState,
|
||||||
generation: initialGenerationState,
|
generation: initialGenerationState,
|
||||||
lightbox: initialLightboxState,
|
lightbox: initialLightboxState,
|
||||||
models: initialModelsState,
|
|
||||||
nodes: initialNodesState,
|
nodes: initialNodesState,
|
||||||
postprocessing: initialPostprocessingState,
|
postprocessing: initialPostprocessingState,
|
||||||
system: initialSystemState,
|
system: initialSystemState,
|
||||||
|
@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
|
|||||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||||
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||||
|
import {
|
||||||
|
addImageAddedToBoardFulfilledListener,
|
||||||
|
addImageAddedToBoardRejectedListener,
|
||||||
|
} from './listeners/imageAddedToBoard';
|
||||||
|
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||||
|
import {
|
||||||
|
addImageRemovedFromBoardFulfilledListener,
|
||||||
|
addImageRemovedFromBoardRejectedListener,
|
||||||
|
} from './listeners/imageRemovedFromBoard';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
|
|||||||
AppDispatch
|
AppDispatch
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||||
|
*
|
||||||
|
* Most side effect logic should live in a listener.
|
||||||
|
*/
|
||||||
|
|
||||||
// Image uploaded
|
// Image uploaded
|
||||||
addImageUploadedFulfilledListener();
|
addImageUploadedFulfilledListener();
|
||||||
addImageUploadedRejectedListener();
|
addImageUploadedRejectedListener();
|
||||||
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
|
|||||||
|
|
||||||
// Update image URLs on connect
|
// Update image URLs on connect
|
||||||
addUpdateImageUrlsOnConnectListener();
|
addUpdateImageUrlsOnConnectListener();
|
||||||
|
|
||||||
|
// Boards
|
||||||
|
addImageAddedToBoardFulfilledListener();
|
||||||
|
addImageAddedToBoardRejectedListener();
|
||||||
|
addImageRemovedFromBoardFulfilledListener();
|
||||||
|
addImageRemovedFromBoardRejectedListener();
|
||||||
|
addBoardIdSelectedListener();
|
||||||
|
@ -0,0 +1,99 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||||
|
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
|
||||||
|
import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'boards' });
|
||||||
|
|
||||||
|
export const addBoardIdSelectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: boardIdSelected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const boardId = action.payload;
|
||||||
|
|
||||||
|
// we need to check if we need to fetch more images
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
const allImages = selectImagesAll(state);
|
||||||
|
|
||||||
|
if (!boardId) {
|
||||||
|
// a board was unselected
|
||||||
|
dispatch(imageSelected(allImages[0]?.image_name));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { categories } = state.images;
|
||||||
|
|
||||||
|
const filteredImages = allImages.filter((i) => {
|
||||||
|
const isInCategory = categories.includes(i.image_category);
|
||||||
|
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||||
|
return isInCategory && isInSelectedBoard;
|
||||||
|
});
|
||||||
|
|
||||||
|
// get the board from the cache
|
||||||
|
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||||
|
const board = boards?.find((b) => b.board_id === boardId);
|
||||||
|
|
||||||
|
if (!board) {
|
||||||
|
// can't find the board in cache...
|
||||||
|
dispatch(imageSelected(allImages[0]?.image_name));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(imageSelected(board.cover_image_name));
|
||||||
|
|
||||||
|
// if we haven't loaded one full page of images from this board, load more
|
||||||
|
if (
|
||||||
|
filteredImages.length < board.image_count &&
|
||||||
|
filteredImages.length < IMAGES_PER_PAGE
|
||||||
|
) {
|
||||||
|
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addBoardIdSelected_changeSelectedImage_listener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: boardIdSelected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const boardId = action.payload;
|
||||||
|
|
||||||
|
const state = getState();
|
||||||
|
|
||||||
|
// we need to check if we need to fetch more images
|
||||||
|
|
||||||
|
if (!boardId) {
|
||||||
|
// a board was unselected - we don't need to do anything
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { categories } = state.images;
|
||||||
|
|
||||||
|
const filteredImages = selectImagesAll(state).filter((i) => {
|
||||||
|
const isInCategory = categories.includes(i.image_category);
|
||||||
|
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||||
|
return isInCategory && isInSelectedBoard;
|
||||||
|
});
|
||||||
|
|
||||||
|
// get the board from the cache
|
||||||
|
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||||
|
const board = boards?.find((b) => b.board_id === boardId);
|
||||||
|
if (!board) {
|
||||||
|
// can't find the board in cache...
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if we haven't loaded one full page of images from this board, load more
|
||||||
|
if (
|
||||||
|
filteredImages.length < board.image_count &&
|
||||||
|
filteredImages.length < IMAGES_PER_PAGE
|
||||||
|
) {
|
||||||
|
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
|
|||||||
[controlNet.processorNode.id]: {
|
[controlNet.processorNode.id]: {
|
||||||
...controlNet.processorNode,
|
...controlNet.processorNode,
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
image: pick(controlNet.controlImage, ['image_name']),
|
image: { image_name: controlNet.controlImage },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
|
|||||||
dispatch(
|
dispatch(
|
||||||
controlNetProcessedImageChanged({
|
controlNetProcessedImageChanged({
|
||||||
controlNetId,
|
controlNetId,
|
||||||
processedControlImage,
|
processedControlImage: processedControlImage.image_name,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageMetadataReceived } from 'services/thunks/image';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'boards' });
|
||||||
|
|
||||||
|
export const addImageAddedToBoardFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: api.endpoints.addImageToBoard.matchFulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { board_id, image_name } },
|
||||||
|
'Image added to board'
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imageMetadataReceived({
|
||||||
|
imageName: image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addImageAddedToBoardRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: api.endpoints.addImageToBoard.matchRejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { board_id, image_name } },
|
||||||
|
'Problem adding image to board'
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
|
|||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageCategoriesChanged,
|
actionCreator: imageCategoriesChanged,
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: (action, { getState, dispatch }) => {
|
||||||
const filteredImagesCount = selectFilteredImagesAsArray(
|
const state = getState();
|
||||||
getState()
|
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
|
||||||
).length;
|
|
||||||
|
|
||||||
if (!filteredImagesCount) {
|
if (!filteredImagesCount) {
|
||||||
dispatch(receivedPageOfImages());
|
dispatch(
|
||||||
|
receivedPageOfImages({
|
||||||
|
categories: action.payload,
|
||||||
|
boardId: state.boards.selectedBoardId,
|
||||||
|
})
|
||||||
|
);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
|
|||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import {
|
import {
|
||||||
imageRemoved,
|
imageRemoved,
|
||||||
selectImagesEntities,
|
|
||||||
selectImagesIds,
|
selectImagesIds,
|
||||||
} from 'features/gallery/store/imagesSlice';
|
} from 'features/gallery/store/imagesSlice';
|
||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
const moduleLog = log.child({ namespace: 'image' });
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called when the user requests an image deletion
|
* Called when the user requests an image deletion
|
||||||
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
|||||||
export const addRequestedImageDeletionListener = () => {
|
export const addRequestedImageDeletionListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: requestedImageDeletion,
|
actionCreator: requestedImageDeletion,
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState, condition }) => {
|
||||||
const { image, imageUsage } = action.payload;
|
const { image, imageUsage } = action.payload;
|
||||||
|
|
||||||
const { image_name } = image;
|
const { image_name } = image;
|
||||||
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
const state = getState();
|
const state = getState();
|
||||||
const selectedImage = state.gallery.selectedImage;
|
const selectedImage = state.gallery.selectedImage;
|
||||||
|
|
||||||
if (selectedImage && selectedImage.image_name === image_name) {
|
if (selectedImage === image_name) {
|
||||||
const ids = selectImagesIds(state);
|
const ids = selectImagesIds(state);
|
||||||
const entities = selectImagesEntities(state);
|
|
||||||
|
|
||||||
const deletedImageIndex = ids.findIndex(
|
const deletedImageIndex = ids.findIndex(
|
||||||
(result) => result.toString() === image_name
|
(result) => result.toString() === image_name
|
||||||
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
|
|
||||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||||
|
|
||||||
const newSelectedImage = entities[newSelectedImageId];
|
|
||||||
|
|
||||||
if (newSelectedImageId) {
|
if (newSelectedImageId) {
|
||||||
dispatch(imageSelected(newSelectedImage));
|
dispatch(imageSelected(newSelectedImageId as string));
|
||||||
} else {
|
} else {
|
||||||
dispatch(imageSelected());
|
dispatch(imageSelected());
|
||||||
}
|
}
|
||||||
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
dispatch(imageRemoved(image_name));
|
dispatch(imageRemoved(image_name));
|
||||||
|
|
||||||
// Delete from server
|
// Delete from server
|
||||||
dispatch(imageDeleted({ imageName: image_name }));
|
const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
|
||||||
|
|
||||||
|
// Wait for successful deletion, then trigger boards to re-fetch
|
||||||
|
const wasImageDeleted = await condition(
|
||||||
|
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
|
||||||
|
imageDeleted.fulfilled.match(action) &&
|
||||||
|
action.meta.requestId === requestId,
|
||||||
|
30000
|
||||||
|
);
|
||||||
|
|
||||||
|
if (wasImageDeleted) {
|
||||||
|
dispatch(
|
||||||
|
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
|
||||||
|
);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { imageMetadataReceived } from 'services/thunks/image';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'boards' });
|
||||||
|
|
||||||
|
export const addImageRemovedFromBoardFulfilledListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { board_id, image_name } },
|
||||||
|
'Image added to board'
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
imageMetadataReceived({
|
||||||
|
imageName: image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
export const addImageRemovedFromBoardRejectedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
matcher: api.endpoints.removeImageFromBoard.matchRejected,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||||
|
|
||||||
|
moduleLog.debug(
|
||||||
|
{ data: { board_id, image_name } },
|
||||||
|
'Problem adding image to board'
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
|
|
||||||
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
||||||
const { controlNetId } = postUploadAction;
|
const { controlNetId } = postUploadAction;
|
||||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
|
dispatch(
|
||||||
|
controlNetImageChanged({
|
||||||
|
controlNetId,
|
||||||
|
controlImage: image.image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import { startAppListening } from '../..';
|
|
||||||
import { log } from 'app/logging/useLogger';
|
import { log } from 'app/logging/useLogger';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { receivedPageOfImages } from 'services/thunks/image';
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
|
import { startAppListening } from '../..';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'socketio' });
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
|
|
||||||
@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
|
|
||||||
moduleLog.debug({ timestamp }, 'Connected');
|
moduleLog.debug({ timestamp }, 'Connected');
|
||||||
|
|
||||||
const { models, nodes, config, images } = getState();
|
const { nodes, config, images } = getState();
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
if (!images.ids.length) {
|
if (!images.ids.length) {
|
||||||
dispatch(receivedPageOfImages());
|
dispatch(
|
||||||
}
|
receivedPageOfImages({
|
||||||
|
categories: ['general'],
|
||||||
if (!models.ids.length) {
|
isIntermediate: false,
|
||||||
dispatch(receivedModels());
|
})
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||||
|
@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
|
|||||||
import { sessionCanceled } from 'services/thunks/session';
|
import { sessionCanceled } from 'services/thunks/session';
|
||||||
import { isImageOutput } from 'services/types/guards';
|
import { isImageOutput } from 'services/types/guards';
|
||||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'socketio' });
|
const moduleLog = log.child({ namespace: 'socketio' });
|
||||||
const nodeDenylist = ['dataURL_image'];
|
const nodeDenylist = ['dataURL_image'];
|
||||||
@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
|
|
||||||
const sessionId = action.payload.data.graph_execution_state_id;
|
const sessionId = action.payload.data.graph_execution_state_id;
|
||||||
|
|
||||||
const { cancelType, isCancelScheduled } = getState().system;
|
const { cancelType, isCancelScheduled, boardIdToAddTo } =
|
||||||
|
getState().system;
|
||||||
|
|
||||||
// Handle scheduled cancelation
|
// Handle scheduled cancelation
|
||||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||||
@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
dispatch(addImageToStagingArea(imageDTO));
|
dispatch(addImageToStagingArea(imageDTO));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (boardIdToAddTo && !imageDTO.is_intermediate) {
|
||||||
|
dispatch(
|
||||||
|
api.endpoints.addImageToBoard.initiate({
|
||||||
|
board_id: boardIdToAddTo,
|
||||||
|
image_name,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(progressImageSet(null));
|
dispatch(progressImageSet(null));
|
||||||
}
|
}
|
||||||
// pass along the socket event as an application action
|
// pass along the socket event as an application action
|
||||||
|
@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
|
|||||||
selectImagesEntities,
|
selectImagesEntities,
|
||||||
],
|
],
|
||||||
(generation, canvas, nodes, controlNet, imageEntities) => {
|
(generation, canvas, nodes, controlNet, imageEntities) => {
|
||||||
const allUsedImages: ImageDTO[] = [];
|
const allUsedImages: string[] = [];
|
||||||
|
|
||||||
if (generation.initialImage) {
|
if (generation.initialImage) {
|
||||||
allUsedImages.push(generation.initialImage);
|
allUsedImages.push(generation.initialImage.imageName);
|
||||||
}
|
}
|
||||||
|
|
||||||
canvas.layerState.objects.forEach((obj) => {
|
canvas.layerState.objects.forEach((obj) => {
|
||||||
if (obj.kind === 'image') {
|
if (obj.kind === 'image') {
|
||||||
allUsedImages.push(obj.image);
|
allUsedImages.push(obj.imageName);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
|
|||||||
|
|
||||||
forEach(imageEntities, (image) => {
|
forEach(imageEntities, (image) => {
|
||||||
if (image) {
|
if (image) {
|
||||||
allUsedImages.push(image);
|
allUsedImages.push(image.image_name);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
|
|||||||
`Fetching new image URLs for ${allUsedImages.length} images`
|
`Fetching new image URLs for ${allUsedImages.length} images`
|
||||||
);
|
);
|
||||||
|
|
||||||
allUsedImages.forEach(({ image_name }) => {
|
allUsedImages.forEach((image_name) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
imageUrlsReceived({
|
imageUrlsReceived({
|
||||||
imageName: image_name,
|
imageName: image_name,
|
||||||
|
@ -5,40 +5,39 @@ import {
|
|||||||
configureStore,
|
configureStore,
|
||||||
} from '@reduxjs/toolkit';
|
} from '@reduxjs/toolkit';
|
||||||
|
|
||||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
|
||||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||||
|
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||||
|
|
||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
|
||||||
import uiReducer from 'features/ui/store/uiSlice';
|
|
||||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
|
||||||
import modelsReducer from 'features/system/store/modelSlice';
|
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
|
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||||
|
import configReducer from 'features/system/store/configSlice';
|
||||||
|
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||||
|
import uiReducer from 'features/ui/store/uiSlice';
|
||||||
|
|
||||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||||
|
|
||||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
|
||||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||||
|
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||||
|
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
const allReducers = {
|
const allReducers = {
|
||||||
canvas: canvasReducer,
|
canvas: canvasReducer,
|
||||||
gallery: galleryReducer,
|
gallery: galleryReducer,
|
||||||
generation: generationReducer,
|
generation: generationReducer,
|
||||||
lightbox: lightboxReducer,
|
lightbox: lightboxReducer,
|
||||||
models: modelsReducer,
|
|
||||||
nodes: nodesReducer,
|
nodes: nodesReducer,
|
||||||
postprocessing: postprocessingReducer,
|
postprocessing: postprocessingReducer,
|
||||||
system: systemReducer,
|
system: systemReducer,
|
||||||
@ -47,7 +46,9 @@ const allReducers = {
|
|||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
images: imagesReducer,
|
images: imagesReducer,
|
||||||
controlNet: controlNetReducer,
|
controlNet: controlNetReducer,
|
||||||
|
boards: boardsReducer,
|
||||||
// session: sessionReducer,
|
// session: sessionReducer,
|
||||||
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
const rootReducer = combineReducers(allReducers);
|
const rootReducer = combineReducers(allReducers);
|
||||||
@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'gallery',
|
'gallery',
|
||||||
'generation',
|
'generation',
|
||||||
'lightbox',
|
'lightbox',
|
||||||
// 'models',
|
|
||||||
'nodes',
|
'nodes',
|
||||||
'postprocessing',
|
'postprocessing',
|
||||||
'system',
|
'system',
|
||||||
'ui',
|
'ui',
|
||||||
'controlNet',
|
'controlNet',
|
||||||
|
// 'boards',
|
||||||
// 'hotkeys',
|
// 'hotkeys',
|
||||||
// 'config',
|
// 'config',
|
||||||
];
|
];
|
||||||
@ -84,6 +85,7 @@ export const store = configureStore({
|
|||||||
immutableCheck: false,
|
immutableCheck: false,
|
||||||
serializableCheck: false,
|
serializableCheck: false,
|
||||||
})
|
})
|
||||||
|
.concat(api.middleware)
|
||||||
.concat(dynamicMiddlewares)
|
.concat(dynamicMiddlewares)
|
||||||
.prepend(listenerMiddleware.middleware),
|
.prepend(listenerMiddleware.middleware),
|
||||||
devTools: {
|
devTools: {
|
||||||
|
@ -9,7 +9,7 @@ import {
|
|||||||
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||||
import { useCombinedRefs } from '@dnd-kit/utilities';
|
import { useCombinedRefs } from '@dnd-kit/utilities';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||||
import { AnimatePresence } from 'framer-motion';
|
import { AnimatePresence } from 'framer-motion';
|
||||||
import { ReactElement, SyntheticEvent, useCallback } from 'react';
|
import { ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||||
@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
|||||||
isDropDisabled = false,
|
isDropDisabled = false,
|
||||||
isDragDisabled = false,
|
isDragDisabled = false,
|
||||||
isUploadDisabled = false,
|
isUploadDisabled = false,
|
||||||
fallback = <IAIImageFallback />,
|
fallback = <IAIImageLoadingFallback />,
|
||||||
payloadImage,
|
payloadImage,
|
||||||
minSize = 24,
|
minSize = 24,
|
||||||
postUploadAction,
|
postUploadAction,
|
||||||
|
@ -1,10 +1,20 @@
|
|||||||
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
|
import {
|
||||||
|
As,
|
||||||
|
Flex,
|
||||||
|
FlexProps,
|
||||||
|
Icon,
|
||||||
|
IconProps,
|
||||||
|
Spinner,
|
||||||
|
SpinnerProps,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { ReactElement } from 'react';
|
||||||
|
import { FaImage } from 'react-icons/fa';
|
||||||
|
|
||||||
type Props = FlexProps & {
|
type Props = FlexProps & {
|
||||||
spinnerProps?: SpinnerProps;
|
spinnerProps?: SpinnerProps;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const IAIImageFallback = (props: Props) => {
|
export const IAIImageLoadingFallback = (props: Props) => {
|
||||||
const { spinnerProps, ...rest } = props;
|
const { spinnerProps, ...rest } = props;
|
||||||
const { sx, ...restFlexProps } = rest;
|
const { sx, ...restFlexProps } = rest;
|
||||||
return (
|
return (
|
||||||
@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
type IAINoImageFallbackProps = {
|
||||||
|
flexProps?: FlexProps;
|
||||||
|
iconProps?: IconProps;
|
||||||
|
as?: As;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
|
||||||
|
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
|
||||||
|
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
bg: 'base.900',
|
||||||
|
opacity: 0.7,
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
...flexSx,
|
||||||
|
}}
|
||||||
|
{...restFlexProps}
|
||||||
|
>
|
||||||
|
<Icon
|
||||||
|
as={props.as ?? FaImage}
|
||||||
|
sx={{ color: 'base.700', ...iconSx }}
|
||||||
|
{...restIconProps}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
@ -1,14 +1,21 @@
|
|||||||
import { Image } from 'react-konva';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import { Image, Rect } from 'react-konva';
|
||||||
|
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||||
import useImage from 'use-image';
|
import useImage from 'use-image';
|
||||||
|
import { CanvasImage } from '../store/canvasTypes';
|
||||||
|
|
||||||
type IAICanvasImageProps = {
|
type IAICanvasImageProps = {
|
||||||
url: string;
|
canvasImage: CanvasImage;
|
||||||
x: number;
|
|
||||||
y: number;
|
|
||||||
};
|
};
|
||||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||||
const { url, x, y } = props;
|
const { width, height, x, y, imageName } = props.canvasImage;
|
||||||
const [image] = useImage(url, 'anonymous');
|
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
|
||||||
|
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
|
||||||
|
|
||||||
|
if (!imageDTO) {
|
||||||
|
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
||||||
|
}
|
||||||
|
|
||||||
return <Image x={x} y={y} image={image} listening={false} />;
|
return <Image x={x} y={y} image={image} listening={false} />;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
|
|||||||
<Group name="outpainting-objects" listening={false}>
|
<Group name="outpainting-objects" listening={false}>
|
||||||
{objects.map((obj, i) => {
|
{objects.map((obj, i) => {
|
||||||
if (isCanvasBaseImage(obj)) {
|
if (isCanvasBaseImage(obj)) {
|
||||||
return (
|
return <IAICanvasImage key={i} canvasImage={obj} />;
|
||||||
<IAICanvasImage
|
|
||||||
key={i}
|
|
||||||
x={obj.x}
|
|
||||||
y={obj.y}
|
|
||||||
url={obj.image.image_url}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
} else if (isCanvasBaseLine(obj)) {
|
} else if (isCanvasBaseLine(obj)) {
|
||||||
const line = (
|
const line = (
|
||||||
<Line
|
<Line
|
||||||
|
@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
return (
|
return (
|
||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||||
<IAICanvasImage
|
<IAICanvasImage canvasImage={currentStagingAreaImage} />
|
||||||
url={currentStagingAreaImage.image.image_url}
|
|
||||||
x={x}
|
|
||||||
y={y}
|
|
||||||
/>
|
|
||||||
)}
|
)}
|
||||||
{shouldShowStagingOutline && (
|
{shouldShowStagingOutline && (
|
||||||
<Group>
|
<Group>
|
||||||
|
@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
|
|||||||
y: 0,
|
y: 0,
|
||||||
width: width,
|
width: width,
|
||||||
height: height,
|
height: height,
|
||||||
image: image,
|
imageName: image.image_name,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
|
|||||||
kind: 'image',
|
kind: 'image',
|
||||||
layer: 'base',
|
layer: 'base',
|
||||||
...state.layerState.stagingArea.boundingBox,
|
...state.layerState.stagingArea.boundingBox,
|
||||||
image,
|
imageName: image.image_name,
|
||||||
});
|
});
|
||||||
|
|
||||||
state.layerState.stagingArea.selectedImageIndex =
|
state.layerState.stagingArea.selectedImageIndex =
|
||||||
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
|
|||||||
state.doesCanvasNeedScaling = true;
|
state.doesCanvasNeedScaling = true;
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||||
|
|
||||||
state.layerState.objects.forEach((object) => {
|
// state.layerState.objects.forEach((object) => {
|
||||||
if (object.kind === 'image') {
|
// if (object.kind === 'image') {
|
||||||
if (object.image.image_name === image_name) {
|
// if (object.image.image_name === image_name) {
|
||||||
object.image.image_url = image_url;
|
// object.image.image_url = image_url;
|
||||||
object.image.thumbnail_url = thumbnail_url;
|
// object.image.thumbnail_url = thumbnail_url;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
|
|
||||||
state.layerState.stagingArea.images.forEach((stagedImage) => {
|
// state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||||
if (stagedImage.image.image_name === image_name) {
|
// if (stagedImage.image.image_name === image_name) {
|
||||||
stagedImage.image.image_url = image_url;
|
// stagedImage.image.image_url = image_url;
|
||||||
stagedImage.image.thumbnail_url = thumbnail_url;
|
// stagedImage.image.thumbnail_url = thumbnail_url;
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
});
|
// });
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ export type CanvasImage = {
|
|||||||
y: number;
|
y: number;
|
||||||
width: number;
|
width: number;
|
||||||
height: number;
|
height: number;
|
||||||
image: ImageDTO;
|
imageName: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CanvasMaskLine = {
|
export type CanvasMaskLine = {
|
||||||
|
@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { AnimatePresence, motion } from 'framer-motion';
|
import { AnimatePresence, motion } from 'framer-motion';
|
||||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { FaUndo } from 'react-icons/fa';
|
import { FaUndo } from 'react-icons/fa';
|
||||||
|
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
@ -31,24 +33,45 @@ type Props = {
|
|||||||
|
|
||||||
const ControlNetImagePreview = (props: Props) => {
|
const ControlNetImagePreview = (props: Props) => {
|
||||||
const { imageSx } = props;
|
const { imageSx } = props;
|
||||||
const { controlNetId, controlImage, processedControlImage, processorType } =
|
const {
|
||||||
props.controlNet;
|
controlNetId,
|
||||||
|
controlImage: controlImageName,
|
||||||
|
processedControlImage: processedControlImageName,
|
||||||
|
processorType,
|
||||||
|
} = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { pendingControlImages } = useAppSelector(selector);
|
const { pendingControlImages } = useAppSelector(selector);
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: controlImage,
|
||||||
|
isLoading: isLoadingControlImage,
|
||||||
|
isError: isErrorControlImage,
|
||||||
|
isSuccess: isSuccessControlImage,
|
||||||
|
} = useGetImageDTOQuery(controlImageName ?? skipToken);
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: processedControlImage,
|
||||||
|
isLoading: isLoadingProcessedControlImage,
|
||||||
|
isError: isErrorProcessedControlImage,
|
||||||
|
isSuccess: isSuccessProcessedControlImage,
|
||||||
|
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
const handleDrop = useCallback(
|
||||||
(droppedImage: ImageDTO) => {
|
(droppedImage: ImageDTO) => {
|
||||||
if (controlImage?.image_name === droppedImage.image_name) {
|
if (controlImageName === droppedImage.image_name) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setIsMouseOverImage(false);
|
setIsMouseOverImage(false);
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
|
controlNetImageChanged({
|
||||||
|
controlNetId,
|
||||||
|
controlImage: droppedImage.image_name,
|
||||||
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
[controlImage, controlNetId, dispatch]
|
[controlImageName, controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleResetControlImage = useCallback(() => {
|
const handleResetControlImage = useCallback(() => {
|
||||||
@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
|
|||||||
h: 'full',
|
h: 'full',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAIImageFallback />
|
<IAIImageLoadingFallback />
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
{controlImage && (
|
{controlImage && (
|
||||||
|
@ -39,8 +39,8 @@ export type ControlNetConfig = {
|
|||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
controlImage: ImageDTO | null;
|
controlImage: string | null;
|
||||||
processedControlImage: ImageDTO | null;
|
processedControlImage: string | null;
|
||||||
processorType: ControlNetProcessorType;
|
processorType: ControlNetProcessorType;
|
||||||
processorNode: RequiredControlNetProcessorNode;
|
processorNode: RequiredControlNetProcessorNode;
|
||||||
shouldAutoConfig: boolean;
|
shouldAutoConfig: boolean;
|
||||||
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
|
|||||||
},
|
},
|
||||||
controlNetAddedFromImage: (
|
controlNetAddedFromImage: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
|
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, controlImage } = action.payload;
|
const { controlNetId, controlImage } = action.payload;
|
||||||
state.controlNets[controlNetId] = {
|
state.controlNets[controlNetId] = {
|
||||||
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
controlImage: ImageDTO | null;
|
controlImage: string | null;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, controlImage } = action.payload;
|
const { controlNetId, controlImage } = action.payload;
|
||||||
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
|
|||||||
state,
|
state,
|
||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
controlNetId: string;
|
controlNetId: string;
|
||||||
processedControlImage: ImageDTO | null;
|
processedControlImage: string | null;
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, processedControlImage } = action.payload;
|
const { controlNetId, processedControlImage } = action.payload;
|
||||||
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
|
|||||||
// Preemptively remove the image from the gallery
|
// Preemptively remove the image from the gallery
|
||||||
const { imageName } = action.meta.arg;
|
const { imageName } = action.meta.arg;
|
||||||
forEach(state.controlNets, (c) => {
|
forEach(state.controlNets, (c) => {
|
||||||
if (c.controlImage?.image_name === imageName) {
|
if (c.controlImage === imageName) {
|
||||||
c.controlImage = null;
|
c.controlImage = null;
|
||||||
c.processedControlImage = null;
|
c.processedControlImage = null;
|
||||||
}
|
}
|
||||||
if (c.processedControlImage?.image_name === imageName) {
|
if (c.processedControlImage === imageName) {
|
||||||
c.processedControlImage = null;
|
c.processedControlImage = null;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||||
|
|
||||||
forEach(state.controlNets, (c) => {
|
// forEach(state.controlNets, (c) => {
|
||||||
if (c.controlImage?.image_name === image_name) {
|
// if (c.controlImage?.image_name === image_name) {
|
||||||
c.controlImage.image_url = image_url;
|
// c.controlImage.image_url = image_url;
|
||||||
c.controlImage.thumbnail_url = thumbnail_url;
|
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||||
}
|
// }
|
||||||
if (c.processedControlImage?.image_name === image_name) {
|
// if (c.processedControlImage?.image_name === image_name) {
|
||||||
c.processedControlImage.image_url = image_url;
|
// c.processedControlImage.image_url = image_url;
|
||||||
c.processedControlImage.thumbnail_url = thumbnail_url;
|
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
});
|
// });
|
||||||
|
|
||||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||||
state.pendingControlImages = [];
|
state.pendingControlImages = [];
|
||||||
|
@ -0,0 +1,27 @@
|
|||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { useCreateBoardMutation } from 'services/apiSlice';
|
||||||
|
|
||||||
|
const DEFAULT_BOARD_NAME = 'My Board';
|
||||||
|
|
||||||
|
const AddBoardButton = () => {
|
||||||
|
const [createBoard, { isLoading }] = useCreateBoardMutation();
|
||||||
|
|
||||||
|
const handleCreateBoard = useCallback(() => {
|
||||||
|
createBoard(DEFAULT_BOARD_NAME);
|
||||||
|
}, [createBoard]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIButton
|
||||||
|
isLoading={isLoading}
|
||||||
|
aria-label="Add Board"
|
||||||
|
onClick={handleCreateBoard}
|
||||||
|
size="sm"
|
||||||
|
sx={{ px: 4 }}
|
||||||
|
>
|
||||||
|
Add Board
|
||||||
|
</IAIButton>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default AddBoardButton;
|
@ -0,0 +1,93 @@
|
|||||||
|
import { Flex, Text } from '@chakra-ui/react';
|
||||||
|
import { FaImages } from 'react-icons/fa';
|
||||||
|
import { boardIdSelected } from '../../store/boardSlice';
|
||||||
|
import { useDispatch } from 'react-redux';
|
||||||
|
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { AnimatePresence } from 'framer-motion';
|
||||||
|
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||||
|
import { useDroppable } from '@dnd-kit/core';
|
||||||
|
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||||
|
|
||||||
|
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||||
|
const dispatch = useDispatch();
|
||||||
|
|
||||||
|
const handleAllImagesBoardClick = () => {
|
||||||
|
dispatch(boardIdSelected());
|
||||||
|
};
|
||||||
|
|
||||||
|
const [removeImageFromBoard, { isLoading }] =
|
||||||
|
useRemoveImageFromBoardMutation();
|
||||||
|
|
||||||
|
const handleDrop = useCallback(
|
||||||
|
(droppedImage: ImageDTO) => {
|
||||||
|
if (!droppedImage.board_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
removeImageFromBoard({
|
||||||
|
board_id: droppedImage.board_id,
|
||||||
|
image_name: droppedImage.image_name,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
[removeImageFromBoard]
|
||||||
|
);
|
||||||
|
|
||||||
|
const {
|
||||||
|
isOver,
|
||||||
|
setNodeRef,
|
||||||
|
active: isDropActive,
|
||||||
|
} = useDroppable({
|
||||||
|
id: `board_droppable_all_images`,
|
||||||
|
data: {
|
||||||
|
handleDrop,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
alignItems: 'center',
|
||||||
|
cursor: 'pointer',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
borderRadius: 'base',
|
||||||
|
}}
|
||||||
|
onClick={handleAllImagesBoardClick}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
ref={setNodeRef}
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
w: 'full',
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
|
||||||
|
<AnimatePresence>
|
||||||
|
{isSelected && <SelectedItemOverlay />}
|
||||||
|
</AnimatePresence>
|
||||||
|
<AnimatePresence>
|
||||||
|
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||||
|
</AnimatePresence>
|
||||||
|
</Flex>
|
||||||
|
<Text
|
||||||
|
sx={{
|
||||||
|
color: isSelected ? 'base.50' : 'base.200',
|
||||||
|
fontWeight: isSelected ? 600 : undefined,
|
||||||
|
fontSize: 'xs',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
All Images
|
||||||
|
</Text>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default AllImagesBoard;
|
@ -0,0 +1,134 @@
|
|||||||
|
import {
|
||||||
|
Collapse,
|
||||||
|
Flex,
|
||||||
|
Grid,
|
||||||
|
IconButton,
|
||||||
|
Input,
|
||||||
|
InputGroup,
|
||||||
|
InputRightElement,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import {
|
||||||
|
boardsSelector,
|
||||||
|
setBoardSearchText,
|
||||||
|
} from 'features/gallery/store/boardSlice';
|
||||||
|
import { memo, useState } from 'react';
|
||||||
|
import HoverableBoard from './HoverableBoard';
|
||||||
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
|
import AddBoardButton from './AddBoardButton';
|
||||||
|
import AllImagesBoard from './AllImagesBoard';
|
||||||
|
import { CloseIcon } from '@chakra-ui/icons';
|
||||||
|
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[boardsSelector],
|
||||||
|
(boardsState) => {
|
||||||
|
const { selectedBoardId, searchText } = boardsState;
|
||||||
|
return { selectedBoardId, searchText };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
isOpen: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const BoardsList = (props: Props) => {
|
||||||
|
const { isOpen } = props;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { selectedBoardId, searchText } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const { data: boards } = useListAllBoardsQuery();
|
||||||
|
|
||||||
|
const filteredBoards = searchText
|
||||||
|
? boards?.filter((board) =>
|
||||||
|
board.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||||
|
)
|
||||||
|
: boards;
|
||||||
|
|
||||||
|
const [searchMode, setSearchMode] = useState(false);
|
||||||
|
|
||||||
|
const handleBoardSearch = (searchTerm: string) => {
|
||||||
|
setSearchMode(searchTerm.length > 0);
|
||||||
|
dispatch(setBoardSearchText(searchTerm));
|
||||||
|
};
|
||||||
|
const clearBoardSearch = () => {
|
||||||
|
setSearchMode(false);
|
||||||
|
dispatch(setBoardSearchText(''));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Collapse in={isOpen} animateOpacity>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
gap: 2,
|
||||||
|
bg: 'base.800',
|
||||||
|
borderRadius: 'base',
|
||||||
|
p: 2,
|
||||||
|
mt: 2,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||||
|
<InputGroup>
|
||||||
|
<Input
|
||||||
|
placeholder="Search Boards..."
|
||||||
|
value={searchText}
|
||||||
|
onChange={(e) => {
|
||||||
|
handleBoardSearch(e.target.value);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
{searchText && searchText.length && (
|
||||||
|
<InputRightElement>
|
||||||
|
<IconButton
|
||||||
|
onClick={clearBoardSearch}
|
||||||
|
size="xs"
|
||||||
|
variant="ghost"
|
||||||
|
aria-label="Clear Search"
|
||||||
|
icon={<CloseIcon boxSize={3} />}
|
||||||
|
/>
|
||||||
|
</InputRightElement>
|
||||||
|
)}
|
||||||
|
</InputGroup>
|
||||||
|
<AddBoardButton />
|
||||||
|
</Flex>
|
||||||
|
<OverlayScrollbarsComponent
|
||||||
|
defer
|
||||||
|
style={{ height: '100%', width: '100%' }}
|
||||||
|
options={{
|
||||||
|
scrollbars: {
|
||||||
|
visibility: 'auto',
|
||||||
|
autoHide: 'move',
|
||||||
|
autoHideDelay: 1300,
|
||||||
|
theme: 'os-theme-dark',
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Grid
|
||||||
|
className="list-container"
|
||||||
|
sx={{
|
||||||
|
gap: 2,
|
||||||
|
gridTemplateRows: '5.5rem 5.5rem',
|
||||||
|
gridAutoFlow: 'column dense',
|
||||||
|
gridAutoColumns: '4rem',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
|
||||||
|
{filteredBoards &&
|
||||||
|
filteredBoards.map((board) => (
|
||||||
|
<HoverableBoard
|
||||||
|
key={board.board_id}
|
||||||
|
board={board}
|
||||||
|
isSelected={selectedBoardId === board.board_id}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</Grid>
|
||||||
|
</OverlayScrollbarsComponent>
|
||||||
|
</Flex>
|
||||||
|
</Collapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(BoardsList);
|
@ -0,0 +1,193 @@
|
|||||||
|
import {
|
||||||
|
Badge,
|
||||||
|
Box,
|
||||||
|
Editable,
|
||||||
|
EditableInput,
|
||||||
|
EditablePreview,
|
||||||
|
Flex,
|
||||||
|
Image,
|
||||||
|
MenuItem,
|
||||||
|
MenuList,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||||
|
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||||
|
import { BoardDTO, ImageDTO } from 'services/api';
|
||||||
|
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||||
|
import {
|
||||||
|
useAddImageToBoardMutation,
|
||||||
|
useDeleteBoardMutation,
|
||||||
|
useGetImageDTOQuery,
|
||||||
|
useUpdateBoardMutation,
|
||||||
|
} from 'services/apiSlice';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
import { useDroppable } from '@dnd-kit/core';
|
||||||
|
import { AnimatePresence } from 'framer-motion';
|
||||||
|
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||||
|
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||||
|
|
||||||
|
interface HoverableBoardProps {
|
||||||
|
board: BoardDTO;
|
||||||
|
isSelected: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { data: coverImage } = useGetImageDTOQuery(
|
||||||
|
board.cover_image_name ?? skipToken
|
||||||
|
);
|
||||||
|
|
||||||
|
const { board_name, board_id } = board;
|
||||||
|
|
||||||
|
const handleSelectBoard = useCallback(() => {
|
||||||
|
dispatch(boardIdSelected(board_id));
|
||||||
|
}, [board_id, dispatch]);
|
||||||
|
|
||||||
|
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||||
|
useUpdateBoardMutation();
|
||||||
|
|
||||||
|
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
|
||||||
|
useDeleteBoardMutation();
|
||||||
|
|
||||||
|
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
|
||||||
|
useAddImageToBoardMutation();
|
||||||
|
|
||||||
|
const handleUpdateBoardName = (newBoardName: string) => {
|
||||||
|
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDeleteBoard = useCallback(() => {
|
||||||
|
deleteBoard(board_id);
|
||||||
|
}, [board_id, deleteBoard]);
|
||||||
|
|
||||||
|
const handleDrop = useCallback(
|
||||||
|
(droppedImage: ImageDTO) => {
|
||||||
|
if (droppedImage.board_id === board_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addImageToBoard({ board_id, image_name: droppedImage.image_name });
|
||||||
|
},
|
||||||
|
[addImageToBoard, board_id]
|
||||||
|
);
|
||||||
|
|
||||||
|
const {
|
||||||
|
isOver,
|
||||||
|
setNodeRef,
|
||||||
|
active: isDropActive,
|
||||||
|
} = useDroppable({
|
||||||
|
id: `board_droppable_${board_id}`,
|
||||||
|
data: {
|
||||||
|
handleDrop,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box sx={{ touchAction: 'none' }}>
|
||||||
|
<ContextMenu<HTMLDivElement>
|
||||||
|
menuProps={{ size: 'sm', isLazy: true }}
|
||||||
|
renderMenu={() => (
|
||||||
|
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||||
|
<MenuItem
|
||||||
|
sx={{ color: 'error.300' }}
|
||||||
|
icon={<FaTrash />}
|
||||||
|
onClickCapture={handleDeleteBoard}
|
||||||
|
>
|
||||||
|
Delete Board
|
||||||
|
</MenuItem>
|
||||||
|
</MenuList>
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{(ref) => (
|
||||||
|
<Flex
|
||||||
|
key={board_id}
|
||||||
|
userSelect="none"
|
||||||
|
ref={ref}
|
||||||
|
sx={{
|
||||||
|
flexDir: 'column',
|
||||||
|
justifyContent: 'space-between',
|
||||||
|
alignItems: 'center',
|
||||||
|
cursor: 'pointer',
|
||||||
|
w: 'full',
|
||||||
|
h: 'full',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
ref={setNodeRef}
|
||||||
|
onClick={handleSelectBoard}
|
||||||
|
sx={{
|
||||||
|
position: 'relative',
|
||||||
|
justifyContent: 'center',
|
||||||
|
alignItems: 'center',
|
||||||
|
borderRadius: 'base',
|
||||||
|
w: 'full',
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
overflow: 'hidden',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{board.cover_image_name && coverImage?.image_url && (
|
||||||
|
<Image src={coverImage?.image_url} draggable={false} />
|
||||||
|
)}
|
||||||
|
{!(board.cover_image_name && coverImage?.image_url) && (
|
||||||
|
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
|
||||||
|
)}
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
top: 0,
|
||||||
|
p: 1,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Badge variant="solid">{board.image_count}</Badge>
|
||||||
|
</Flex>
|
||||||
|
<AnimatePresence>
|
||||||
|
{isSelected && <SelectedItemOverlay />}
|
||||||
|
</AnimatePresence>
|
||||||
|
<AnimatePresence>
|
||||||
|
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||||
|
</AnimatePresence>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Box sx={{ width: 'full' }}>
|
||||||
|
<Editable
|
||||||
|
defaultValue={board_name}
|
||||||
|
submitOnBlur={false}
|
||||||
|
onSubmit={(nextValue) => {
|
||||||
|
handleUpdateBoardName(nextValue);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<EditablePreview
|
||||||
|
sx={{
|
||||||
|
color: isSelected ? 'base.50' : 'base.200',
|
||||||
|
fontWeight: isSelected ? 600 : undefined,
|
||||||
|
fontSize: 'xs',
|
||||||
|
textAlign: 'center',
|
||||||
|
p: 0,
|
||||||
|
}}
|
||||||
|
noOfLines={1}
|
||||||
|
/>
|
||||||
|
<EditableInput
|
||||||
|
sx={{
|
||||||
|
color: 'base.50',
|
||||||
|
fontSize: 'xs',
|
||||||
|
borderColor: 'base.500',
|
||||||
|
p: 0,
|
||||||
|
outline: 0,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Editable>
|
||||||
|
</Box>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</ContextMenu>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
HoverableBoard.displayName = 'HoverableBoard';
|
||||||
|
|
||||||
|
export default HoverableBoard;
|
@ -0,0 +1,93 @@
|
|||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
Box,
|
||||||
|
Flex,
|
||||||
|
Spinner,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
|
||||||
|
import { memo, useContext, useRef, useState } from 'react';
|
||||||
|
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
|
const UpdateImageBoardModal = () => {
|
||||||
|
// const boards = useSelector(selectBoardsAll);
|
||||||
|
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||||
|
const { isOpen, onClose, handleAddToBoard, image } = useContext(
|
||||||
|
AddImageToBoardContext
|
||||||
|
);
|
||||||
|
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||||
|
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
const currentBoard = boards?.find(
|
||||||
|
(board) => board.board_id === image?.board_id
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog
|
||||||
|
isOpen={isOpen}
|
||||||
|
leastDestructiveRef={cancelRef}
|
||||||
|
onClose={onClose}
|
||||||
|
isCentered
|
||||||
|
>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
<Box>
|
||||||
|
<Flex direction="column" gap={3}>
|
||||||
|
{currentBoard && (
|
||||||
|
<Text>
|
||||||
|
Moving this image from{' '}
|
||||||
|
<strong>{currentBoard.board_name}</strong> to
|
||||||
|
</Text>
|
||||||
|
)}
|
||||||
|
{isFetching ? (
|
||||||
|
<Spinner />
|
||||||
|
) : (
|
||||||
|
<IAIMantineSelect
|
||||||
|
placeholder="Select Board"
|
||||||
|
onChange={(v) => setSelectedBoard(v)}
|
||||||
|
value={selectedBoard}
|
||||||
|
data={(boards ?? []).map((board) => ({
|
||||||
|
label: board.board_name,
|
||||||
|
value: board.board_id,
|
||||||
|
}))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Box>
|
||||||
|
</AlertDialogBody>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<IAIButton onClick={onClose}>Cancel</IAIButton>
|
||||||
|
<IAIButton
|
||||||
|
isDisabled={!selectedBoard}
|
||||||
|
colorScheme="accent"
|
||||||
|
onClick={() => {
|
||||||
|
if (selectedBoard) {
|
||||||
|
handleAddToBoard(selectedBoard);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
ml={3}
|
||||||
|
>
|
||||||
|
{currentBoard ? 'Move' : 'Add'}
|
||||||
|
</IAIButton>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(UpdateImageBoardModal);
|
@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
|
|||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||||
import { DeleteImageButton } from './DeleteImageModal';
|
import { DeleteImageButton } from './DeleteImageModal';
|
||||||
|
import { selectImagesById } from '../store/imagesSlice';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
const currentImageButtonsSelector = createSelector(
|
const currentImageButtonsSelector = createSelector(
|
||||||
[
|
[
|
||||||
|
(state: RootState) => state,
|
||||||
systemSelector,
|
systemSelector,
|
||||||
gallerySelector,
|
gallerySelector,
|
||||||
postprocessingSelector,
|
postprocessingSelector,
|
||||||
@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
lightboxSelector,
|
lightboxSelector,
|
||||||
activeTabNameSelector,
|
activeTabNameSelector,
|
||||||
],
|
],
|
||||||
(system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||||
const {
|
const {
|
||||||
isProcessing,
|
isProcessing,
|
||||||
isConnected,
|
isConnected,
|
||||||
@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
shouldShowProgressInViewer,
|
shouldShowProgressInViewer,
|
||||||
} = ui;
|
} = ui;
|
||||||
|
|
||||||
|
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
|
||||||
|
|
||||||
const { selectedImage } = gallery;
|
const { selectedImage } = gallery;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
|
|||||||
activeTabName,
|
activeTabName,
|
||||||
isLightboxOpen,
|
isLightboxOpen,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
image: selectedImage,
|
image: imageDTO,
|
||||||
seed: selectedImage?.metadata?.seed,
|
seed: imageDTO?.metadata?.seed,
|
||||||
prompt: selectedImage?.metadata?.positive_conditioning,
|
prompt: imageDTO?.metadata?.positive_conditioning,
|
||||||
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
negativePrompt: imageDTO?.metadata?.negative_conditioning,
|
||||||
shouldShowProgressInViewer,
|
shouldShowProgressInViewer,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
|||||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||||
import { configSelector } from '../../system/store/configSelectors';
|
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
|
||||||
import { imageSelected } from '../store/gallerySlice';
|
import { imageSelected } from '../store/gallerySlice';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
|
||||||
export const imagesSelector = createSelector(
|
export const imagesSelector = createSelector(
|
||||||
[uiSelector, gallerySelector, systemSelector],
|
[uiSelector, gallerySelector, systemSelector],
|
||||||
@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
|
|||||||
return {
|
return {
|
||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
shouldHidePreview,
|
shouldHidePreview,
|
||||||
image: selectedImage,
|
selectedImage,
|
||||||
progressImage,
|
progressImage,
|
||||||
shouldShowProgressInViewer,
|
shouldShowProgressInViewer,
|
||||||
shouldAntialiasProgressImage,
|
shouldAntialiasProgressImage,
|
||||||
@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
|
|||||||
const CurrentImagePreview = () => {
|
const CurrentImagePreview = () => {
|
||||||
const {
|
const {
|
||||||
shouldShowImageDetails,
|
shouldShowImageDetails,
|
||||||
image,
|
selectedImage,
|
||||||
progressImage,
|
progressImage,
|
||||||
shouldShowProgressInViewer,
|
shouldShowProgressInViewer,
|
||||||
shouldAntialiasProgressImage,
|
shouldAntialiasProgressImage,
|
||||||
} = useAppSelector(imagesSelector);
|
} = useAppSelector(imagesSelector);
|
||||||
|
|
||||||
|
// const image = useAppSelector((state: RootState) =>
|
||||||
|
// selectImagesById(state, selectedImage ?? '')
|
||||||
|
// );
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: image,
|
||||||
|
isLoading,
|
||||||
|
isError,
|
||||||
|
isSuccess,
|
||||||
|
} = useGetImageDTOQuery(selectedImage ?? skipToken);
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
const handleDrop = useCallback(
|
||||||
@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
|
|||||||
if (droppedImage.image_name === image?.image_name) {
|
if (droppedImage.image_name === image?.image_name) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(imageSelected(droppedImage));
|
dispatch(imageSelected(droppedImage.image_name));
|
||||||
},
|
},
|
||||||
[dispatch, image?.image_name]
|
[dispatch, image?.image_name]
|
||||||
);
|
);
|
||||||
@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
image={image}
|
image={selectedImage && image ? image : undefined}
|
||||||
onDrop={handleDrop}
|
onDrop={handleDrop}
|
||||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||||
isUploadDisabled={true}
|
isUploadDisabled={true}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
{shouldShowImageDetails && image && (
|
{shouldShowImageDetails && image && selectedImage && (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
@ -119,7 +131,7 @@ const CurrentImagePreview = () => {
|
|||||||
<ImageMetadataViewer image={image} />
|
<ImageMetadataViewer image={image} />
|
||||||
</Box>
|
</Box>
|
||||||
)}
|
)}
|
||||||
{!shouldShowImageDetails && image && (
|
{!shouldShowImageDetails && image && selectedImage && (
|
||||||
<Box
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
|
@ -2,7 +2,14 @@ import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useContext, useState } from 'react';
|
import { memo, useCallback, useContext, useState } from 'react';
|
||||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
import {
|
||||||
|
FaCheck,
|
||||||
|
FaExpand,
|
||||||
|
FaFolder,
|
||||||
|
FaImage,
|
||||||
|
FaShare,
|
||||||
|
FaTrash,
|
||||||
|
} from 'react-icons/fa';
|
||||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||||
import {
|
import {
|
||||||
resizeAndScaleCanvas,
|
resizeAndScaleCanvas,
|
||||||
@ -27,6 +34,8 @@ import { useAppToaster } from 'app/components/Toaster';
|
|||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { useDraggable } from '@dnd-kit/core';
|
import { useDraggable } from '@dnd-kit/core';
|
||||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||||
|
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||||
|
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||||
@ -62,17 +71,10 @@ interface HoverableImageProps {
|
|||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const memoEqualityCheck = (
|
|
||||||
prev: HoverableImageProps,
|
|
||||||
next: HoverableImageProps
|
|
||||||
) =>
|
|
||||||
prev.image.image_name === next.image.image_name &&
|
|
||||||
prev.isSelected === next.isSelected;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||||
*/
|
*/
|
||||||
const HoverableImage = memo((props: HoverableImageProps) => {
|
const HoverableImage = (props: HoverableImageProps) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const {
|
const {
|
||||||
activeTabName,
|
activeTabName,
|
||||||
@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||||
|
|
||||||
const { onDelete } = useContext(DeleteImageContext);
|
const { onDelete } = useContext(DeleteImageContext);
|
||||||
|
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
onDelete(image);
|
onDelete(image);
|
||||||
}, [image, onDelete]);
|
}, [image, onDelete]);
|
||||||
@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||||
|
|
||||||
const handleMouseOver = () => setIsHovered(true);
|
const handleMouseOver = () => setIsHovered(true);
|
||||||
const handleMouseOut = () => setIsHovered(false);
|
const handleMouseOut = () => setIsHovered(false);
|
||||||
|
|
||||||
const handleSelectImage = useCallback(() => {
|
const handleSelectImage = useCallback(() => {
|
||||||
dispatch(imageSelected(image));
|
dispatch(imageSelected(image.image_name));
|
||||||
}, [image, dispatch]);
|
}, [image, dispatch]);
|
||||||
|
|
||||||
// Recall parameters handlers
|
// Recall parameters handlers
|
||||||
@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
// dispatch(setIsLightboxOpen(true));
|
// dispatch(setIsLightboxOpen(true));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const handleAddToBoard = useCallback(() => {
|
||||||
|
onClickAddToBoard(image);
|
||||||
|
}, [image, onClickAddToBoard]);
|
||||||
|
|
||||||
|
const handleRemoveFromBoard = useCallback(() => {
|
||||||
|
if (!image.board_id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||||
|
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||||
|
|
||||||
const handleOpenInNewTab = () => {
|
const handleOpenInNewTab = () => {
|
||||||
window.open(image.image_url, '_blank');
|
window.open(image.image_url, '_blank');
|
||||||
};
|
};
|
||||||
@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
{t('parameters.sendToUnifiedCanvas')}
|
{t('parameters.sendToUnifiedCanvas')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
)}
|
)}
|
||||||
|
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||||
|
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||||
|
</MenuItem>
|
||||||
|
{image.board_id && (
|
||||||
|
<MenuItem
|
||||||
|
icon={<FaFolder />}
|
||||||
|
onClickCapture={handleRemoveFromBoard}
|
||||||
|
>
|
||||||
|
Remove from Board
|
||||||
|
</MenuItem>
|
||||||
|
)}
|
||||||
<MenuItem
|
<MenuItem
|
||||||
sx={{ color: 'error.300' }}
|
sx={{ color: 'error.300' }}
|
||||||
icon={<FaTrash />}
|
icon={<FaTrash />}
|
||||||
@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
</ContextMenu>
|
</ContextMenu>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
}, memoEqualityCheck);
|
};
|
||||||
|
|
||||||
HoverableImage.displayName = 'HoverableImage';
|
export default memo(HoverableImage);
|
||||||
|
|
||||||
export default HoverableImage;
|
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
import {
|
import {
|
||||||
Box,
|
Box,
|
||||||
|
Button,
|
||||||
ButtonGroup,
|
ButtonGroup,
|
||||||
Flex,
|
Flex,
|
||||||
FlexProps,
|
FlexProps,
|
||||||
Grid,
|
Grid,
|
||||||
Icon,
|
Icon,
|
||||||
Text,
|
Text,
|
||||||
|
VStack,
|
||||||
forwardRef,
|
forwardRef,
|
||||||
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
@ -20,6 +23,7 @@ import {
|
|||||||
setGalleryImageObjectFit,
|
setGalleryImageObjectFit,
|
||||||
setShouldAutoSwitchToNewImages,
|
setShouldAutoSwitchToNewImages,
|
||||||
setShouldUseSingleGalleryColumn,
|
setShouldUseSingleGalleryColumn,
|
||||||
|
setGalleryView,
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
||||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||||
@ -53,41 +57,51 @@ import {
|
|||||||
selectImagesAll,
|
selectImagesAll,
|
||||||
} from '../store/imagesSlice';
|
} from '../store/imagesSlice';
|
||||||
import { receivedPageOfImages } from 'services/thunks/image';
|
import { receivedPageOfImages } from 'services/thunks/image';
|
||||||
|
import BoardsList from './Boards/BoardsList';
|
||||||
|
import { boardsSelector } from '../store/boardSlice';
|
||||||
|
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||||
|
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const categorySelector = createSelector(
|
const itemSelector = createSelector(
|
||||||
[(state: RootState) => state],
|
[(state: RootState) => state],
|
||||||
(state) => {
|
(state) => {
|
||||||
const { images } = state;
|
const { categories, total: allImagesTotal, isLoading } = state.images;
|
||||||
const { categories } = images;
|
const { selectedBoardId } = state.boards;
|
||||||
|
|
||||||
const allImages = selectImagesAll(state);
|
const allImages = selectImagesAll(state);
|
||||||
const filteredImages = allImages.filter((i) =>
|
|
||||||
categories.includes(i.image_category)
|
const images = allImages.filter((i) => {
|
||||||
);
|
const isInCategory = categories.includes(i.image_category);
|
||||||
|
const isInSelectedBoard = selectedBoardId
|
||||||
|
? i.board_id === selectedBoardId
|
||||||
|
: true;
|
||||||
|
return isInCategory && isInSelectedBoard;
|
||||||
|
});
|
||||||
|
|
||||||
return {
|
return {
|
||||||
images: filteredImages,
|
images,
|
||||||
isLoading: images.isLoading,
|
allImagesTotal,
|
||||||
areMoreImagesAvailable: filteredImages.length < images.total,
|
isLoading,
|
||||||
categories: images.categories,
|
categories,
|
||||||
|
selectedBoardId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const mainSelector = createSelector(
|
const mainSelector = createSelector(
|
||||||
[gallerySelector, uiSelector],
|
[gallerySelector, uiSelector, boardsSelector],
|
||||||
(gallery, ui) => {
|
(gallery, ui, boards) => {
|
||||||
const {
|
const {
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
galleryImageObjectFit,
|
galleryImageObjectFit,
|
||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
shouldUseSingleGalleryColumn,
|
shouldUseSingleGalleryColumn,
|
||||||
selectedImage,
|
selectedImage,
|
||||||
|
galleryView,
|
||||||
} = gallery;
|
} = gallery;
|
||||||
|
|
||||||
const { shouldPinGallery } = ui;
|
const { shouldPinGallery } = ui;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
shouldPinGallery,
|
shouldPinGallery,
|
||||||
galleryImageMinimumWidth,
|
galleryImageMinimumWidth,
|
||||||
@ -95,6 +109,8 @@ const mainSelector = createSelector(
|
|||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
shouldUseSingleGalleryColumn,
|
shouldUseSingleGalleryColumn,
|
||||||
selectedImage,
|
selectedImage,
|
||||||
|
galleryView,
|
||||||
|
selectedBoardId: boards.selectedBoardId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
|
|||||||
shouldAutoSwitchToNewImages,
|
shouldAutoSwitchToNewImages,
|
||||||
shouldUseSingleGalleryColumn,
|
shouldUseSingleGalleryColumn,
|
||||||
selectedImage,
|
selectedImage,
|
||||||
|
galleryView,
|
||||||
} = useAppSelector(mainSelector);
|
} = useAppSelector(mainSelector);
|
||||||
|
|
||||||
const { images, areMoreImagesAvailable, isLoading, categories } =
|
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
|
||||||
useAppSelector(categorySelector);
|
useAppSelector(itemSelector);
|
||||||
|
|
||||||
|
const { selectedBoard } = useListAllBoardsQuery(undefined, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const filteredImagesTotal = useMemo(
|
||||||
|
() => selectedBoard?.image_count ?? allImagesTotal,
|
||||||
|
[allImagesTotal, selectedBoard?.image_count]
|
||||||
|
);
|
||||||
|
|
||||||
|
const areMoreAvailable = useMemo(() => {
|
||||||
|
return images.length < filteredImagesTotal;
|
||||||
|
}, [images.length, filteredImagesTotal]);
|
||||||
|
|
||||||
const handleLoadMoreImages = useCallback(() => {
|
const handleLoadMoreImages = useCallback(() => {
|
||||||
dispatch(receivedPageOfImages());
|
dispatch(
|
||||||
}, [dispatch]);
|
receivedPageOfImages({
|
||||||
|
categories,
|
||||||
|
boardId: selectedBoardId,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}, [categories, dispatch, selectedBoardId]);
|
||||||
|
|
||||||
const handleEndReached = useMemo(() => {
|
const handleEndReached = useMemo(() => {
|
||||||
if (areMoreImagesAvailable && !isLoading) {
|
if (areMoreAvailable && !isLoading) {
|
||||||
return handleLoadMoreImages;
|
return handleLoadMoreImages;
|
||||||
}
|
}
|
||||||
return undefined;
|
return undefined;
|
||||||
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
|
}, [areMoreAvailable, handleLoadMoreImages, isLoading]);
|
||||||
|
|
||||||
|
const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
|
||||||
|
|
||||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||||
dispatch(setGalleryImageMinimumWidth(v));
|
dispatch(setGalleryImageMinimumWidth(v));
|
||||||
@ -172,46 +211,79 @@ const ImageGalleryContent = () => {
|
|||||||
|
|
||||||
const handleClickImagesCategory = useCallback(() => {
|
const handleClickImagesCategory = useCallback(() => {
|
||||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||||
|
dispatch(setGalleryView('images'));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleClickAssetsCategory = useCallback(() => {
|
const handleClickAssetsCategory = useCallback(() => {
|
||||||
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
||||||
|
dispatch(setGalleryView('assets'));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex
|
<VStack
|
||||||
sx={{
|
sx={{
|
||||||
gap: 2,
|
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
h: 'full',
|
h: 'full',
|
||||||
w: 'full',
|
w: 'full',
|
||||||
borderRadius: 'base',
|
borderRadius: 'base',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Flex
|
<Box sx={{ w: 'full' }}>
|
||||||
ref={resizeObserverRef}
|
<Flex
|
||||||
alignItems="center"
|
ref={resizeObserverRef}
|
||||||
justifyContent="space-between"
|
sx={{
|
||||||
>
|
alignItems: 'center',
|
||||||
<ButtonGroup isAttached>
|
justifyContent: 'space-between',
|
||||||
<IAIIconButton
|
gap: 2,
|
||||||
tooltip={t('gallery.images')}
|
}}
|
||||||
aria-label={t('gallery.images')}
|
>
|
||||||
onClick={handleClickImagesCategory}
|
<ButtonGroup isAttached>
|
||||||
isChecked={categories === IMAGE_CATEGORIES}
|
<IAIIconButton
|
||||||
|
tooltip={t('gallery.images')}
|
||||||
|
aria-label={t('gallery.images')}
|
||||||
|
onClick={handleClickImagesCategory}
|
||||||
|
isChecked={galleryView === 'images'}
|
||||||
|
size="sm"
|
||||||
|
icon={<FaImage />}
|
||||||
|
/>
|
||||||
|
<IAIIconButton
|
||||||
|
tooltip={t('gallery.assets')}
|
||||||
|
aria-label={t('gallery.assets')}
|
||||||
|
onClick={handleClickAssetsCategory}
|
||||||
|
isChecked={galleryView === 'assets'}
|
||||||
|
size="sm"
|
||||||
|
icon={<FaServer />}
|
||||||
|
/>
|
||||||
|
</ButtonGroup>
|
||||||
|
<Flex
|
||||||
|
as={Button}
|
||||||
|
onClick={onToggle}
|
||||||
size="sm"
|
size="sm"
|
||||||
icon={<FaImage />}
|
variant="ghost"
|
||||||
/>
|
sx={{
|
||||||
<IAIIconButton
|
w: 'full',
|
||||||
tooltip={t('gallery.assets')}
|
justifyContent: 'center',
|
||||||
aria-label={t('gallery.assets')}
|
alignItems: 'center',
|
||||||
onClick={handleClickAssetsCategory}
|
px: 2,
|
||||||
isChecked={categories === ASSETS_CATEGORIES}
|
_hover: {
|
||||||
size="sm"
|
bg: 'base.800',
|
||||||
icon={<FaServer />}
|
},
|
||||||
/>
|
}}
|
||||||
</ButtonGroup>
|
>
|
||||||
<Flex gap={2}>
|
<Text
|
||||||
|
noOfLines={1}
|
||||||
|
sx={{ w: 'full', color: 'base.200', fontWeight: 600 }}
|
||||||
|
>
|
||||||
|
{selectedBoard ? selectedBoard.board_name : 'All Images'}
|
||||||
|
</Text>
|
||||||
|
<ChevronUpIcon
|
||||||
|
sx={{
|
||||||
|
transform: isBoardListOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||||
|
transitionProperty: 'common',
|
||||||
|
transitionDuration: 'normal',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
<IAIPopover
|
<IAIPopover
|
||||||
triggerComponent={
|
triggerComponent={
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
@ -269,9 +341,12 @@ const ImageGalleryContent = () => {
|
|||||||
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
|
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
<Box>
|
||||||
<Flex direction="column" gap={2} h="full">
|
<BoardsList isOpen={isBoardListOpen} />
|
||||||
{images.length || areMoreImagesAvailable ? (
|
</Box>
|
||||||
|
</Box>
|
||||||
|
<Flex direction="column" gap={2} h="full" w="full">
|
||||||
|
{images.length || areMoreAvailable ? (
|
||||||
<>
|
<>
|
||||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
||||||
{shouldUseSingleGalleryColumn ? (
|
{shouldUseSingleGalleryColumn ? (
|
||||||
@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
|
|||||||
data={images}
|
data={images}
|
||||||
endReached={handleEndReached}
|
endReached={handleEndReached}
|
||||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||||
itemContent={(index, image) => (
|
itemContent={(index, item) => (
|
||||||
<Flex sx={{ pb: 2 }}>
|
<Flex sx={{ pb: 2 }}>
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||||
image={image}
|
image={item}
|
||||||
isSelected={
|
isSelected={selectedImage === item?.image_name}
|
||||||
selectedImage?.image_name === image?.image_name
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
|
|||||||
List: ListContainer,
|
List: ListContainer,
|
||||||
}}
|
}}
|
||||||
scrollerRef={setScroller}
|
scrollerRef={setScroller}
|
||||||
itemContent={(index, image) => (
|
itemContent={(index, item) => (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||||
image={image}
|
image={item}
|
||||||
isSelected={
|
isSelected={selectedImage === item?.image_name}
|
||||||
selectedImage?.image_name === image?.image_name
|
|
||||||
}
|
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
/>
|
/>
|
||||||
@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
|
|||||||
</Box>
|
</Box>
|
||||||
<IAIButton
|
<IAIButton
|
||||||
onClick={handleLoadMoreImages}
|
onClick={handleLoadMoreImages}
|
||||||
isDisabled={!areMoreImagesAvailable}
|
isDisabled={!areMoreAvailable}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
loadingText="Loading"
|
loadingText="Loading"
|
||||||
flexShrink={0}
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
{areMoreImagesAvailable
|
{areMoreAvailable
|
||||||
? t('gallery.loadMore')
|
? t('gallery.loadMore')
|
||||||
: t('gallery.allImagesLoaded')}
|
: t('gallery.allImagesLoaded')}
|
||||||
</IAIButton>
|
</IAIButton>
|
||||||
@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</VStack>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
|
|||||||
image: ImageDTO;
|
image: ImageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: I don't know if this is needed.
|
|
||||||
const memoEqualityCheck = (
|
|
||||||
prev: ImageMetadataViewerProps,
|
|
||||||
next: ImageMetadataViewerProps
|
|
||||||
) => prev.image.image_name === next.image.image_name;
|
|
||||||
|
|
||||||
// TODO: Show more interesting information in this component.
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image metadata viewer overlays currently selected image and provides
|
* Image metadata viewer overlays currently selected image and provides
|
||||||
* access to any of its metadata for use in processing.
|
* access to any of its metadata for use in processing.
|
||||||
*/
|
*/
|
||||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const {
|
const {
|
||||||
recallBothPrompts,
|
recallBothPrompts,
|
||||||
@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
}, memoEqualityCheck);
|
};
|
||||||
|
|
||||||
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
export default memo(ImageMetadataViewer);
|
||||||
|
|
||||||
export default ImageMetadataViewer;
|
|
||||||
|
@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const currentImageIndex = filteredImageIds.findIndex(
|
const currentImageIndex = filteredImageIds.findIndex(
|
||||||
(i) => i === selectedImage.image_name
|
(i) => i === selectedImage
|
||||||
);
|
);
|
||||||
|
|
||||||
const nextImageIndex = clamp(
|
const nextImageIndex = clamp(
|
||||||
@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
|
|||||||
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
||||||
nextImage,
|
nextImage,
|
||||||
prevImage,
|
prevImage,
|
||||||
|
nextImageId,
|
||||||
|
prevImageId,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
|
|||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { isOnFirstImage, isOnLastImage, nextImage, prevImage } =
|
const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
|
||||||
useAppSelector(nextPrevImageButtonsSelector);
|
useAppSelector(nextPrevImageButtonsSelector);
|
||||||
|
|
||||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
|
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
|
||||||
@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
|
|||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handlePrevImage = useCallback(() => {
|
const handlePrevImage = useCallback(() => {
|
||||||
dispatch(imageSelected(prevImage));
|
dispatch(imageSelected(prevImageId));
|
||||||
}, [dispatch, prevImage]);
|
}, [dispatch, prevImageId]);
|
||||||
|
|
||||||
const handleNextImage = useCallback(() => {
|
const handleNextImage = useCallback(() => {
|
||||||
dispatch(imageSelected(nextImage));
|
dispatch(imageSelected(nextImageId));
|
||||||
}, [dispatch, nextImage]);
|
}, [dispatch, nextImageId]);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
'left',
|
'left',
|
||||||
() => {
|
() => {
|
||||||
handlePrevImage();
|
handlePrevImage();
|
||||||
},
|
},
|
||||||
[prevImage]
|
[prevImageId]
|
||||||
);
|
);
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
|
|||||||
() => {
|
() => {
|
||||||
handleNextImage();
|
handleNextImage();
|
||||||
},
|
},
|
||||||
[nextImage]
|
[nextImageId]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
import { motion } from 'framer-motion';
|
||||||
|
|
||||||
|
export const SelectedItemOverlay = () => (
|
||||||
|
<motion.div
|
||||||
|
initial={{
|
||||||
|
opacity: 0,
|
||||||
|
}}
|
||||||
|
animate={{
|
||||||
|
opacity: 1,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
exit={{
|
||||||
|
opacity: 0,
|
||||||
|
transition: { duration: 0.1 },
|
||||||
|
}}
|
||||||
|
style={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineStart: 0,
|
||||||
|
width: '100%',
|
||||||
|
height: '100%',
|
||||||
|
boxShadow: 'inset 0px 0px 0px 2px var(--invokeai-colors-accent-300)',
|
||||||
|
borderRadius: 'var(--invokeai-radii-base)',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
);
|
@ -0,0 +1,23 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { selectBoardsAll } from './boardSlice';
|
||||||
|
|
||||||
|
export const boardSelector = (state: RootState) => state.boards.entities;
|
||||||
|
|
||||||
|
export const searchBoardsSelector = createSelector(
|
||||||
|
(state: RootState) => state,
|
||||||
|
(state) => {
|
||||||
|
const {
|
||||||
|
boards: { searchText },
|
||||||
|
} = state;
|
||||||
|
|
||||||
|
if (!searchText) {
|
||||||
|
// If no search text provided, return all entities
|
||||||
|
return selectBoardsAll(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
return selectBoardsAll(state).filter((i) =>
|
||||||
|
i.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
);
|
@ -0,0 +1,47 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { api } from 'services/apiSlice';
|
||||||
|
|
||||||
|
type BoardsState = {
|
||||||
|
searchText: string;
|
||||||
|
selectedBoardId?: string;
|
||||||
|
updateBoardModalOpen: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const initialBoardsState: BoardsState = {
|
||||||
|
updateBoardModalOpen: false,
|
||||||
|
searchText: '',
|
||||||
|
};
|
||||||
|
|
||||||
|
const boardsSlice = createSlice({
|
||||||
|
name: 'boards',
|
||||||
|
initialState: initialBoardsState,
|
||||||
|
reducers: {
|
||||||
|
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||||
|
state.selectedBoardId = action.payload;
|
||||||
|
},
|
||||||
|
setBoardSearchText: (state, action: PayloadAction<string>) => {
|
||||||
|
state.searchText = action.payload;
|
||||||
|
},
|
||||||
|
setUpdateBoardModalOpen: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.updateBoardModalOpen = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
extraReducers: (builder) => {
|
||||||
|
builder.addMatcher(
|
||||||
|
api.endpoints.deleteBoard.matchFulfilled,
|
||||||
|
(state, action) => {
|
||||||
|
if (action.meta.arg.originalArgs === state.selectedBoardId) {
|
||||||
|
state.selectedBoardId = undefined;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
);
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
|
||||||
|
boardsSlice.actions;
|
||||||
|
|
||||||
|
export const boardsSelector = (state: RootState) => state.boards;
|
||||||
|
|
||||||
|
export default boardsSlice.reducer;
|
@ -1,17 +1,16 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { ImageDTO } from 'services/api';
|
|
||||||
import { imageUpserted } from './imagesSlice';
|
import { imageUpserted } from './imagesSlice';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
|
||||||
|
|
||||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||||
|
|
||||||
export interface GalleryState {
|
export interface GalleryState {
|
||||||
selectedImage?: ImageDTO;
|
selectedImage?: string;
|
||||||
galleryImageMinimumWidth: number;
|
galleryImageMinimumWidth: number;
|
||||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||||
shouldAutoSwitchToNewImages: boolean;
|
shouldAutoSwitchToNewImages: boolean;
|
||||||
shouldUseSingleGalleryColumn: boolean;
|
shouldUseSingleGalleryColumn: boolean;
|
||||||
|
galleryView: 'images' | 'assets' | 'boards';
|
||||||
}
|
}
|
||||||
|
|
||||||
export const initialGalleryState: GalleryState = {
|
export const initialGalleryState: GalleryState = {
|
||||||
@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
|
|||||||
galleryImageObjectFit: 'cover',
|
galleryImageObjectFit: 'cover',
|
||||||
shouldAutoSwitchToNewImages: true,
|
shouldAutoSwitchToNewImages: true,
|
||||||
shouldUseSingleGalleryColumn: false,
|
shouldUseSingleGalleryColumn: false,
|
||||||
|
galleryView: 'images',
|
||||||
};
|
};
|
||||||
|
|
||||||
export const gallerySlice = createSlice({
|
export const gallerySlice = createSlice({
|
||||||
name: 'gallery',
|
name: 'gallery',
|
||||||
initialState: initialGalleryState,
|
initialState: initialGalleryState,
|
||||||
reducers: {
|
reducers: {
|
||||||
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
|
imageSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||||
state.selectedImage = action.payload;
|
state.selectedImage = action.payload;
|
||||||
// TODO: if the user selects an image, disable the auto switch?
|
// TODO: if the user selects an image, disable the auto switch?
|
||||||
// state.shouldAutoSwitchToNewImages = false;
|
// state.shouldAutoSwitchToNewImages = false;
|
||||||
@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
state.shouldUseSingleGalleryColumn = action.payload;
|
state.shouldUseSingleGalleryColumn = action.payload;
|
||||||
},
|
},
|
||||||
|
setGalleryView: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<'images' | 'assets' | 'boards'>
|
||||||
|
) => {
|
||||||
|
state.galleryView = action.payload;
|
||||||
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(imageUpserted, (state, action) => {
|
builder.addCase(imageUpserted, (state, action) => {
|
||||||
@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
|
|||||||
state.shouldAutoSwitchToNewImages &&
|
state.shouldAutoSwitchToNewImages &&
|
||||||
action.payload.image_category === 'general'
|
action.payload.image_category === 'general'
|
||||||
) {
|
) {
|
||||||
state.selectedImage = action.payload;
|
state.selectedImage = action.payload.image_name;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||||
|
|
||||||
if (state.selectedImage?.image_name === image_name) {
|
// if (state.selectedImage?.image_name === image_name) {
|
||||||
state.selectedImage.image_url = image_url;
|
// state.selectedImage.image_url = image_url;
|
||||||
state.selectedImage.thumbnail_url = thumbnail_url;
|
// state.selectedImage.thumbnail_url = thumbnail_url;
|
||||||
}
|
// }
|
||||||
});
|
// });
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -75,6 +81,7 @@ export const {
|
|||||||
setGalleryImageObjectFit,
|
setGalleryImageObjectFit,
|
||||||
setShouldAutoSwitchToNewImages,
|
setShouldAutoSwitchToNewImages,
|
||||||
setShouldUseSingleGalleryColumn,
|
setShouldUseSingleGalleryColumn,
|
||||||
|
setGalleryView,
|
||||||
} = gallerySlice.actions;
|
} = gallerySlice.actions;
|
||||||
|
|
||||||
export default gallerySlice.reducer;
|
export default gallerySlice.reducer;
|
||||||
|
@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
|
|||||||
import { keyBy } from 'lodash-es';
|
import { keyBy } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
imageDeleted,
|
imageDeleted,
|
||||||
imageMetadataReceived,
|
|
||||||
imageUrlsReceived,
|
imageUrlsReceived,
|
||||||
receivedPageOfImages,
|
receivedPageOfImages,
|
||||||
} from 'services/thunks/image';
|
} from 'services/thunks/image';
|
||||||
@ -74,11 +73,21 @@ const imagesSlice = createSlice({
|
|||||||
});
|
});
|
||||||
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
||||||
state.isLoading = false;
|
state.isLoading = false;
|
||||||
|
const { boardId, categories, imageOrigin, isIntermediate } =
|
||||||
|
action.meta.arg;
|
||||||
|
|
||||||
const { items, offset, limit, total } = action.payload;
|
const { items, offset, limit, total } = action.payload;
|
||||||
|
imagesAdapter.upsertMany(state, items);
|
||||||
|
|
||||||
|
if (!categories?.includes('general') || boardId) {
|
||||||
|
// need to skip updating the total images count if the images recieved were for a specific board
|
||||||
|
// TODO: this doesn't work when on the Asset tab/category...
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
state.offset = offset;
|
state.offset = offset;
|
||||||
state.limit = limit;
|
state.limit = limit;
|
||||||
state.total = total;
|
state.total = total;
|
||||||
imagesAdapter.upsertMany(state, items);
|
|
||||||
});
|
});
|
||||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||||
// Image deleted
|
// Image deleted
|
||||||
@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
|
|||||||
.map((i) => i.image_name);
|
.map((i) => i.image_name);
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// export const selectImageById = createSelector(
|
||||||
|
// (state: RootState, imageId) => state,
|
||||||
|
// (state) => {
|
||||||
|
// const {
|
||||||
|
// images: { categories },
|
||||||
|
// } = state;
|
||||||
|
|
||||||
|
// return selectImagesAll(state)
|
||||||
|
// .filter((i) => categories.includes(i.image_category))
|
||||||
|
// .map((i) => i.image_name);
|
||||||
|
// }
|
||||||
|
// );
|
||||||
|
@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
|
|||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
|
||||||
const ImageInputFieldComponent = (
|
const ImageInputFieldComponent = (
|
||||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||||
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
|
|||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: image,
|
||||||
|
isLoading,
|
||||||
|
isError,
|
||||||
|
isSuccess,
|
||||||
|
} = useGetImageDTOQuery(field.value ?? skipToken);
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
const handleDrop = useCallback(
|
||||||
(droppedImage: ImageDTO) => {
|
(droppedImage: ImageDTO) => {
|
||||||
if (field.value?.image_name === droppedImage.image_name) {
|
if (field.value === droppedImage.image_name) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
|
|||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName: field.name,
|
fieldName: field.name,
|
||||||
value: droppedImage,
|
value: droppedImage.image_name,
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
[dispatch, field.name, field.value?.image_name, nodeId]
|
[dispatch, field.name, field.value, nodeId]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
image={field.value}
|
image={image}
|
||||||
onDrop={handleDrop}
|
onDrop={handleDrop}
|
||||||
onReset={handleReset}
|
onReset={handleReset}
|
||||||
resetIconSize="sm"
|
resetIconSize="sm"
|
||||||
|
@ -1,28 +1,18 @@
|
|||||||
import { Select } from '@chakra-ui/react';
|
import { SelectItem } from '@mantine/core';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import {
|
import {
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
ModelInputFieldValue,
|
ModelInputFieldValue,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { selectModelsIds } from 'features/system/store/modelSlice';
|
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { ChangeEvent, memo } from 'react';
|
|
||||||
import { FieldComponentProps } from './types';
|
|
||||||
|
|
||||||
const availableModelsSelector = createSelector(
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
[selectModelsIds],
|
import { FieldComponentProps } from './types';
|
||||||
(allModelNames) => {
|
import { forEach, isString } from 'lodash-es';
|
||||||
return { allModelNames };
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
// return map(modelList, (_, name) => name);
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
},
|
import { useTranslation } from 'react-i18next';
|
||||||
{
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ModelInputFieldComponent = (
|
const ModelInputFieldComponent = (
|
||||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||||
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
|
|||||||
const { nodeId, field } = props;
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
const data = useMemo(() => {
|
||||||
dispatch(
|
if (!pipelineModels) {
|
||||||
fieldValueChanged({
|
return [];
|
||||||
nodeId,
|
}
|
||||||
fieldName: field.name,
|
|
||||||
value: e.target.value,
|
const data: SelectItem[] = [];
|
||||||
})
|
|
||||||
);
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
};
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||||
|
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleValueChanged(firstModel);
|
||||||
|
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Select
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
onChange={handleValueChanged}
|
onChange={handleValueChanged}
|
||||||
value={field.value || allModelNames[0]}
|
/>
|
||||||
>
|
|
||||||
{allModelNames.map((option) => (
|
|
||||||
<option key={option}>{option}</option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -101,21 +101,6 @@ const nodesSlice = createSlice({
|
|||||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||||
state.schema = action.payload;
|
state.schema = action.payload;
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
state.nodes.forEach((node) => {
|
|
||||||
forEach(node.data.inputs, (input) => {
|
|
||||||
if (input.type === 'image') {
|
|
||||||
if (input.value?.image_name === image_name) {
|
|
||||||
input.value.image_url = image_url;
|
|
||||||
input.value.thumbnail_url = thumbnail_url;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
|
|||||||
|
|
||||||
export type ImageInputFieldValue = FieldValueBase & {
|
export type ImageInputFieldValue = FieldValueBase & {
|
||||||
type: 'image';
|
type: 'image';
|
||||||
value?: ImageDTO;
|
value?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelInputFieldValue = FieldValueBase & {
|
export type ModelInputFieldValue = FieldValueBase & {
|
||||||
|
@ -65,15 +65,13 @@ export const addControlNetToLinearGraph = (
|
|||||||
|
|
||||||
if (processedControlImage && processorType !== 'none') {
|
if (processedControlImage && processorType !== 'none') {
|
||||||
// We've already processed the image in the app, so we can just use the processed image
|
// We've already processed the image in the app, so we can just use the processed image
|
||||||
const { image_name } = processedControlImage;
|
|
||||||
controlNetNode.image = {
|
controlNetNode.image = {
|
||||||
image_name,
|
image_name: processedControlImage,
|
||||||
};
|
};
|
||||||
} else if (controlImage) {
|
} else if (controlImage) {
|
||||||
// The control image is preprocessed
|
// The control image is preprocessed
|
||||||
const { image_name } = controlImage;
|
|
||||||
controlNetNode.image = {
|
controlNetNode.image = {
|
||||||
image_name,
|
image_name: controlImage,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||||
|
@ -23,6 +23,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -49,6 +50,8 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -85,9 +88,9 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -17,6 +17,7 @@ import {
|
|||||||
INPAINT_GRAPH,
|
INPAINT_GRAPH,
|
||||||
INPAINT,
|
INPAINT,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -31,7 +32,7 @@ export const buildCanvasInpaintGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -54,6 +55,8 @@ export const buildCanvasInpaintGraph = (
|
|||||||
// We may need to set the inpaint width and height to scale the image
|
// We may need to set the inpaint width and height to scale the image
|
||||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: INPAINT_GRAPH,
|
id: INPAINT_GRAPH,
|
||||||
nodes: {
|
nodes: {
|
||||||
@ -99,9 +102,9 @@ export const buildCanvasInpaintGraph = (
|
|||||||
prompt: negativePrompt,
|
prompt: negativePrompt,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[RANGE_OF_SIZE]: {
|
[RANGE_OF_SIZE]: {
|
||||||
type: 'range_of_size',
|
type: 'range_of_size',
|
||||||
|
@ -14,6 +14,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
@ -24,7 +25,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -36,6 +37,8 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
// The bounding box determines width and height, not the width and height params
|
// The bounding box determines width and height, not the width and height params
|
||||||
const { width, height } = state.canvas.boundingBoxDimensions;
|
const { width, height } = state.canvas.boundingBoxDimensions;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -80,9 +83,9 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -22,6 +22,7 @@ import {
|
|||||||
} from './constants';
|
} from './constants';
|
||||||
import { set } from 'lodash-es';
|
import { set } from 'lodash-es';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
const moduleLog = log.child({ namespace: 'nodes' });
|
const moduleLog = log.child({ namespace: 'nodes' });
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -62,6 +63,8 @@ export const buildLinearImageToImageGraph = (
|
|||||||
throw new Error('No initial image found in state');
|
throw new Error('No initial image found in state');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
const graph: NonNullableGraph = {
|
const graph: NonNullableGraph = {
|
||||||
id: IMAGE_TO_IMAGE_GRAPH,
|
id: IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -89,9 +92,9 @@ export const buildLinearImageToImageGraph = (
|
|||||||
id: NOISE,
|
id: NOISE,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
@ -274,7 +277,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
id: RESIZE,
|
id: RESIZE,
|
||||||
type: 'img_resize',
|
type: 'img_resize',
|
||||||
image: {
|
image: {
|
||||||
image_name: initialImage.image_name,
|
image_name: initialImage.imageName,
|
||||||
},
|
},
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
width,
|
width,
|
||||||
@ -311,7 +314,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
} else {
|
} else {
|
||||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
|
||||||
image_name: initialImage.image_name,
|
image_name: initialImage.imageName,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Pass the image's dimensions to the `NOISE` node
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api';
|
import {
|
||||||
|
BaseModelType,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
import {
|
import {
|
||||||
ITERATE,
|
ITERATE,
|
||||||
LATENTS_TO_IMAGE,
|
LATENTS_TO_IMAGE,
|
||||||
@ -14,6 +18,7 @@ import {
|
|||||||
TEXT_TO_LATENTS,
|
TEXT_TO_LATENTS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
type TextToImageGraphOverrides = {
|
type TextToImageGraphOverrides = {
|
||||||
width: number;
|
width: number;
|
||||||
@ -27,7 +32,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
negativePrompt,
|
negativePrompt,
|
||||||
model: model_name,
|
model: modelId,
|
||||||
cfgScale: cfg_scale,
|
cfgScale: cfg_scale,
|
||||||
scheduler,
|
scheduler,
|
||||||
steps,
|
steps,
|
||||||
@ -38,6 +43,8 @@ export const buildLinearTextToImageGraph = (
|
|||||||
shouldRandomizeSeed,
|
shouldRandomizeSeed,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
|
const model = modelIdToPipelineModelField(modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
@ -82,9 +89,9 @@ export const buildLinearTextToImageGraph = (
|
|||||||
steps,
|
steps,
|
||||||
},
|
},
|
||||||
[MODEL_LOADER]: {
|
[MODEL_LOADER]: {
|
||||||
type: 'sd1_model_loader',
|
type: 'pipeline_model_loader',
|
||||||
id: MODEL_LOADER,
|
id: MODEL_LOADER,
|
||||||
model_name,
|
model,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { Graph } from 'services/api';
|
import { Graph } from 'services/api';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { cloneDeep, forEach, omit, reduce, values } from 'lodash-es';
|
import { cloneDeep, omit, reduce } from 'lodash-es';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { InputFieldValue } from 'features/nodes/types/types';
|
import { InputFieldValue } from 'features/nodes/types/types';
|
||||||
import { AnyInvocation } from 'services/events/types';
|
import { AnyInvocation } from 'services/events/types';
|
||||||
|
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* We need to do special handling for some fields
|
* We need to do special handling for some fields
|
||||||
@ -24,6 +25,12 @@ export const parseFieldValue = (field: InputFieldValue) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field.type === 'model') {
|
||||||
|
if (field.value) {
|
||||||
|
return modelIdToPipelineModelField(field.value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return field.value;
|
return field.value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ export const NOISE = 'noise';
|
|||||||
export const RANDOM_INT = 'rand_int';
|
export const RANDOM_INT = 'rand_int';
|
||||||
export const RANGE_OF_SIZE = 'range_of_size';
|
export const RANGE_OF_SIZE = 'range_of_size';
|
||||||
export const ITERATE = 'iterate';
|
export const ITERATE = 'iterate';
|
||||||
export const MODEL_LOADER = 'model_loader';
|
export const MODEL_LOADER = 'pipeline_model_loader';
|
||||||
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
export const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
export const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
export const RESIZE = 'resize_image';
|
export const RESIZE = 'resize_image';
|
||||||
|
@ -0,0 +1,18 @@
|
|||||||
|
import { BaseModelType, PipelineModelField } from 'services/api';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Crudely converts a model id to a pipeline model field
|
||||||
|
* TODO: Make better
|
||||||
|
*/
|
||||||
|
export const modelIdToPipelineModelField = (
|
||||||
|
modelId: string
|
||||||
|
): PipelineModelField => {
|
||||||
|
const [base_model, model_type, model_name] = modelId.split('/');
|
||||||
|
|
||||||
|
const field: PipelineModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -57,7 +57,7 @@ export const buildImg2ImgNode = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
imageToImageNode.image = {
|
imageToImageNode.image = {
|
||||||
image_name: initialImage.image_name,
|
image_name: initialImage.imageName,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import ParamScheduler from './ParamScheduler';
|
|||||||
const ParamSchedulerAndModel = () => {
|
const ParamSchedulerAndModel = () => {
|
||||||
return (
|
return (
|
||||||
<Flex gap={3} w="full">
|
<Flex gap={3} w="full">
|
||||||
<Box w="20rem">
|
<Box w="25rem">
|
||||||
<ParamScheduler />
|
<ParamScheduler />
|
||||||
</Box>
|
</Box>
|
||||||
<Box w="full">
|
<Box w="full">
|
||||||
|
@ -10,7 +10,9 @@ import { generationSelector } from 'features/parameters/store/generationSelector
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDndImage from 'common/components/IAIDndImage';
|
import IAIDndImage from 'common/components/IAIDndImage';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||||
|
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||||
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[generationSelector],
|
[generationSelector],
|
||||||
@ -27,14 +29,21 @@ const InitialImagePreview = () => {
|
|||||||
const { initialImage } = useAppSelector(selector);
|
const { initialImage } = useAppSelector(selector);
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const {
|
||||||
|
data: image,
|
||||||
|
isLoading,
|
||||||
|
isError,
|
||||||
|
isSuccess,
|
||||||
|
} = useGetImageDTOQuery(initialImage?.imageName ?? skipToken);
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
const handleDrop = useCallback(
|
||||||
(droppedImage: ImageDTO) => {
|
(droppedImage: ImageDTO) => {
|
||||||
if (droppedImage.image_name === initialImage?.image_name) {
|
if (droppedImage.image_name === initialImage?.imageName) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(initialImageChanged(droppedImage));
|
dispatch(initialImageChanged(droppedImage));
|
||||||
},
|
},
|
||||||
[dispatch, initialImage?.image_name]
|
[dispatch, initialImage]
|
||||||
);
|
);
|
||||||
|
|
||||||
const handleReset = useCallback(() => {
|
const handleReset = useCallback(() => {
|
||||||
@ -53,10 +62,10 @@ const InitialImagePreview = () => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<IAIDndImage
|
<IAIDndImage
|
||||||
image={initialImage}
|
image={image}
|
||||||
onDrop={handleDrop}
|
onDrop={handleDrop}
|
||||||
onReset={handleReset}
|
onReset={handleReset}
|
||||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||||
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
|
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
|
||||||
withResetIcon
|
withResetIcon
|
||||||
/>
|
/>
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { clamp, sortBy } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { ImageDTO } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { imageUrlsReceived } from 'services/thunks/image';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import {
|
import {
|
||||||
CfgScaleParam,
|
CfgScaleParam,
|
||||||
HeightParam,
|
HeightParam,
|
||||||
@ -17,14 +16,13 @@ import {
|
|||||||
StrengthParam,
|
StrengthParam,
|
||||||
WidthParam,
|
WidthParam,
|
||||||
} from './parameterZodSchemas';
|
} from './parameterZodSchemas';
|
||||||
import { DEFAULT_SCHEDULER_NAME } from 'app/constants';
|
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: CfgScaleParam;
|
cfgScale: CfgScaleParam;
|
||||||
height: HeightParam;
|
height: HeightParam;
|
||||||
img2imgStrength: StrengthParam;
|
img2imgStrength: StrengthParam;
|
||||||
infillMethod: string;
|
infillMethod: string;
|
||||||
initialImage?: ImageDTO;
|
initialImage?: { imageName: string; width: number; height: number };
|
||||||
iterations: number;
|
iterations: number;
|
||||||
perlin: number;
|
perlin: number;
|
||||||
positivePrompt: PositivePromptParam;
|
positivePrompt: PositivePromptParam;
|
||||||
@ -212,35 +210,20 @@ export const generationSlice = createSlice({
|
|||||||
state.shouldUseNoiseSettings = action.payload;
|
state.shouldUseNoiseSettings = action.payload;
|
||||||
},
|
},
|
||||||
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
state.initialImage = action.payload;
|
const { image_name, width, height } = action.payload;
|
||||||
|
state.initialImage = { imageName: image_name, width, height };
|
||||||
},
|
},
|
||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
state.model = action.payload;
|
state.model = action.payload;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
if (!state.model) {
|
|
||||||
const firstModel = sortBy(action.payload, 'name')[0];
|
|
||||||
state.model = firstModel.name;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(configChanged, (state, action) => {
|
builder.addCase(configChanged, (state, action) => {
|
||||||
const defaultModel = action.payload.sd?.defaultModel;
|
const defaultModel = action.payload.sd?.defaultModel;
|
||||||
if (defaultModel && !state.model) {
|
if (defaultModel && !state.model) {
|
||||||
state.model = defaultModel;
|
state.model = defaultModel;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
|
||||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
|
||||||
|
|
||||||
if (state.initialImage?.image_name === image_name) {
|
|
||||||
state.initialImage.image_url = image_url;
|
|
||||||
state.initialImage.thumbnail_url = thumbnail_url;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -154,3 +154,17 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
|||||||
*/
|
*/
|
||||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||||
zStrength.safeParse(val).success;
|
zStrength.safeParse(val).success;
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Zod schema for BaseModelType
|
||||||
|
// */
|
||||||
|
// export const zBaseModelType = z.enum(['sd-1', 'sd-2']);
|
||||||
|
// /**
|
||||||
|
// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI.
|
||||||
|
// */
|
||||||
|
// export type BaseModelType = z.infer<typeof zBaseModelType>;
|
||||||
|
// /**
|
||||||
|
// * Validates/type-guards a value as a base model type
|
||||||
|
// */
|
||||||
|
// export const isValidBaseModelType = (val: unknown): val is BaseModelType =>
|
||||||
|
// zBaseModelType.safeParse(val).success;
|
||||||
|
@ -1,44 +1,59 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
import { isEqual } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineSelect, {
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
IAISelectDataType,
|
|
||||||
} from 'common/components/IAIMantineSelect';
|
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
|
||||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
|
||||||
|
|
||||||
const selector = createSelector(
|
import { forEach, isString } from 'lodash-es';
|
||||||
[(state: RootState) => state, generationSelector],
|
import { SelectItem } from '@mantine/core';
|
||||||
(state, generation) => {
|
import { RootState } from 'app/store/store';
|
||||||
const selectedModel = selectModelsById(state, generation.model);
|
import { useListModelsQuery } from 'services/apiSlice';
|
||||||
|
|
||||||
const modelData = selectModelsAll(state)
|
export const MODEL_TYPE_MAP = {
|
||||||
.map<IAISelectDataType>((m) => ({
|
'sd-1': 'Stable Diffusion 1.x',
|
||||||
value: m.name,
|
'sd-2': 'Stable Diffusion 2.x',
|
||||||
label: m.name,
|
};
|
||||||
}))
|
|
||||||
.sort((a, b) => a.label.localeCompare(b.label));
|
|
||||||
return {
|
|
||||||
selectedModel,
|
|
||||||
modelData,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
{
|
|
||||||
memoizeOptions: {
|
|
||||||
resultEqualityCheck: isEqual,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const ModelSelect = () => {
|
const ModelSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { selectedModel, modelData } = useAppSelector(selector);
|
|
||||||
|
const selectedModelId = useAppSelector(
|
||||||
|
(state: RootState) => state.generation.model
|
||||||
|
);
|
||||||
|
|
||||||
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
|
model_type: 'pipeline',
|
||||||
|
});
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!pipelineModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(pipelineModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [pipelineModels]);
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => pipelineModels?.entities[selectedModelId],
|
||||||
|
[pipelineModels?.entities, selectedModelId]
|
||||||
|
);
|
||||||
|
|
||||||
const handleChangeModel = useCallback(
|
const handleChangeModel = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
if (!v) {
|
if (!v) {
|
||||||
@ -49,13 +64,27 @@ const ModelSelect = () => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedModelId && pipelineModels?.ids.includes(selectedModelId)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModel = pipelineModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstModel)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleChangeModel(firstModel);
|
||||||
|
}, [handleChangeModel, pipelineModels?.ids, selectedModelId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<IAIMantineSelect
|
<IAIMantineSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
value={selectedModel?.name ?? ''}
|
value={selectedModelId}
|
||||||
placeholder="Pick one"
|
placeholder="Pick one"
|
||||||
data={modelData}
|
data={data}
|
||||||
onChange={handleChangeModel}
|
onChange={handleChangeModel}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
|
||||||
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas';
|
||||||
@ -16,6 +15,7 @@ const data = map(SCHEDULER_NAMES, (s) => ({
|
|||||||
|
|
||||||
export default function SettingsSchedulers() {
|
export default function SettingsSchedulers() {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const enabledSchedulers = useAppSelector(
|
const enabledSchedulers = useAppSelector(
|
||||||
|
@ -7,13 +7,12 @@ import { systemSelector } from '../store/systemSelectors';
|
|||||||
const isApplicationReadySelector = createSelector(
|
const isApplicationReadySelector = createSelector(
|
||||||
[systemSelector, configSelector],
|
[systemSelector, configSelector],
|
||||||
(system, config) => {
|
(system, config) => {
|
||||||
const { wereModelsReceived, wasSchemaParsed } = system;
|
const { wasSchemaParsed } = system;
|
||||||
|
|
||||||
const { disabledTabs } = config;
|
const { disabledTabs } = config;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
disabledTabs,
|
disabledTabs,
|
||||||
wereModelsReceived,
|
|
||||||
wasSchemaParsed,
|
wasSchemaParsed,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -23,21 +22,17 @@ const isApplicationReadySelector = createSelector(
|
|||||||
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
* Checks if the application is ready to be used, i.e. if the initial startup process is finished.
|
||||||
*/
|
*/
|
||||||
export const useIsApplicationReady = () => {
|
export const useIsApplicationReady = () => {
|
||||||
const { disabledTabs, wereModelsReceived, wasSchemaParsed } = useAppSelector(
|
const { disabledTabs, wasSchemaParsed } = useAppSelector(
|
||||||
isApplicationReadySelector
|
isApplicationReadySelector
|
||||||
);
|
);
|
||||||
|
|
||||||
const isApplicationReady = useMemo(() => {
|
const isApplicationReady = useMemo(() => {
|
||||||
if (!wereModelsReceived) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
if (!disabledTabs.includes('nodes') && !wasSchemaParsed) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}, [disabledTabs, wereModelsReceived, wasSchemaParsed]);
|
}, [disabledTabs, wasSchemaParsed]);
|
||||||
|
|
||||||
return isApplicationReady;
|
return isApplicationReady;
|
||||||
};
|
};
|
||||||
|
@ -1,3 +0,0 @@
|
|||||||
import { RootState } from 'app/store/store';
|
|
||||||
|
|
||||||
export const modelSelector = (state: RootState) => state.models;
|
|
@ -1,47 +0,0 @@
|
|||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import { RootState } from 'app/store/store';
|
|
||||||
import { CkptModelInfo, DiffusersModelInfo } from 'services/api';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
|
|
||||||
export type Model = (CkptModelInfo | DiffusersModelInfo) & {
|
|
||||||
name: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const modelsAdapter = createEntityAdapter<Model>({
|
|
||||||
selectId: (model) => model.name,
|
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
|
||||||
});
|
|
||||||
|
|
||||||
export const initialModelsState = modelsAdapter.getInitialState();
|
|
||||||
|
|
||||||
export type ModelsState = typeof initialModelsState;
|
|
||||||
|
|
||||||
export const modelsSlice = createSlice({
|
|
||||||
name: 'models',
|
|
||||||
initialState: initialModelsState,
|
|
||||||
reducers: {
|
|
||||||
modelAdded: modelsAdapter.upsertOne,
|
|
||||||
},
|
|
||||||
extraReducers(builder) {
|
|
||||||
/**
|
|
||||||
* Received Models - FULFILLED
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state, action) => {
|
|
||||||
const models = action.payload;
|
|
||||||
modelsAdapter.setAll(state, models);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const {
|
|
||||||
selectAll: selectModelsAll,
|
|
||||||
selectById: selectModelsById,
|
|
||||||
selectEntities: selectModelsEntities,
|
|
||||||
selectIds: selectModelsIds,
|
|
||||||
selectTotal: selectModelsTotal,
|
|
||||||
} = modelsAdapter.getSelectors<RootState>((state) => state.models);
|
|
||||||
|
|
||||||
export const { modelAdded } = modelsSlice.actions;
|
|
||||||
|
|
||||||
export default modelsSlice.reducer;
|
|
@ -1,6 +0,0 @@
|
|||||||
import { ModelsState } from './modelSlice';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Models slice persist denylist
|
|
||||||
*/
|
|
||||||
export const modelsPersistDenylist: (keyof ModelsState)[] = ['entities', 'ids'];
|
|
@ -1,20 +1,12 @@
|
|||||||
import { UseToastOptions } from '@chakra-ui/react';
|
import { UseToastOptions } from '@chakra-ui/react';
|
||||||
import { PayloadAction } from '@reduxjs/toolkit';
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
|
|
||||||
import { ProgressImage } from 'services/events/types';
|
|
||||||
import { makeToast } from '../../../app/components/Toaster';
|
|
||||||
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
|
||||||
import { receivedModels } from 'services/thunks/model';
|
|
||||||
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { LogLevelName } from 'roarr';
|
|
||||||
import { InvokeLogLevel } from 'app/logging/useLogger';
|
import { InvokeLogLevel } from 'app/logging/useLogger';
|
||||||
import { TFuncKey } from 'i18next';
|
|
||||||
import { t } from 'i18next';
|
|
||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { LANGUAGES } from '../components/LanguagePicker';
|
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { TFuncKey, t } from 'i18next';
|
||||||
|
import { LogLevelName } from 'roarr';
|
||||||
import {
|
import {
|
||||||
appSocketConnected,
|
appSocketConnected,
|
||||||
appSocketDisconnected,
|
appSocketDisconnected,
|
||||||
@ -26,6 +18,11 @@ import {
|
|||||||
appSocketSubscribed,
|
appSocketSubscribed,
|
||||||
appSocketUnsubscribed,
|
appSocketUnsubscribed,
|
||||||
} from 'services/events/actions';
|
} from 'services/events/actions';
|
||||||
|
import { ProgressImage } from 'services/events/types';
|
||||||
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
|
import { isAnySessionRejected, sessionCanceled } from 'services/thunks/session';
|
||||||
|
import { makeToast } from '../../../app/components/Toaster';
|
||||||
|
import { LANGUAGES } from '../components/LanguagePicker';
|
||||||
|
|
||||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||||
|
|
||||||
@ -95,6 +92,7 @@ export interface SystemState {
|
|||||||
shouldAntialiasProgressImage: boolean;
|
shouldAntialiasProgressImage: boolean;
|
||||||
language: keyof typeof LANGUAGES;
|
language: keyof typeof LANGUAGES;
|
||||||
isUploading: boolean;
|
isUploading: boolean;
|
||||||
|
boardIdToAddTo?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const initialSystemState: SystemState = {
|
export const initialSystemState: SystemState = {
|
||||||
@ -225,6 +223,7 @@ export const systemSlice = createSlice({
|
|||||||
*/
|
*/
|
||||||
builder.addCase(appSocketSubscribed, (state, action) => {
|
builder.addCase(appSocketSubscribed, (state, action) => {
|
||||||
state.sessionId = action.payload.sessionId;
|
state.sessionId = action.payload.sessionId;
|
||||||
|
state.boardIdToAddTo = action.payload.boardId;
|
||||||
state.canceledSession = '';
|
state.canceledSession = '';
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -233,6 +232,7 @@ export const systemSlice = createSlice({
|
|||||||
*/
|
*/
|
||||||
builder.addCase(appSocketUnsubscribed, (state) => {
|
builder.addCase(appSocketUnsubscribed, (state) => {
|
||||||
state.sessionId = null;
|
state.sessionId = null;
|
||||||
|
state.boardIdToAddTo = undefined;
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -376,13 +376,6 @@ export const systemSlice = createSlice({
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* Received available models from the backend
|
|
||||||
*/
|
|
||||||
builder.addCase(receivedModels.fulfilled, (state) => {
|
|
||||||
state.wereModelsReceived = true;
|
|
||||||
});
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* OpenAPI schema was parsed
|
* OpenAPI schema was parsed
|
||||||
*/
|
*/
|
||||||
|
@ -8,6 +8,10 @@ export type { OpenAPIConfig } from './core/OpenAPI';
|
|||||||
|
|
||||||
export type { AddInvocation } from './models/AddInvocation';
|
export type { AddInvocation } from './models/AddInvocation';
|
||||||
export type { BaseModelType } from './models/BaseModelType';
|
export type { BaseModelType } from './models/BaseModelType';
|
||||||
|
export type { BoardChanges } from './models/BoardChanges';
|
||||||
|
export type { BoardDTO } from './models/BoardDTO';
|
||||||
|
export type { Body_create_board_image } from './models/Body_create_board_image';
|
||||||
|
export type { Body_remove_board_image } from './models/Body_remove_board_image';
|
||||||
export type { Body_upload_image } from './models/Body_upload_image';
|
export type { Body_upload_image } from './models/Body_upload_image';
|
||||||
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
export type { CannyImageProcessorInvocation } from './models/CannyImageProcessorInvocation';
|
||||||
export type { CkptModelInfo } from './models/CkptModelInfo';
|
export type { CkptModelInfo } from './models/CkptModelInfo';
|
||||||
@ -21,6 +25,8 @@ export type { ConditioningField } from './models/ConditioningField';
|
|||||||
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
export type { ContentShuffleImageProcessorInvocation } from './models/ContentShuffleImageProcessorInvocation';
|
||||||
export type { ControlField } from './models/ControlField';
|
export type { ControlField } from './models/ControlField';
|
||||||
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
export type { ControlNetInvocation } from './models/ControlNetInvocation';
|
||||||
|
export type { ControlNetModelConfig } from './models/ControlNetModelConfig';
|
||||||
|
export type { ControlNetModelFormat } from './models/ControlNetModelFormat';
|
||||||
export type { ControlOutput } from './models/ControlOutput';
|
export type { ControlOutput } from './models/ControlOutput';
|
||||||
export type { CreateModelRequest } from './models/CreateModelRequest';
|
export type { CreateModelRequest } from './models/CreateModelRequest';
|
||||||
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
export type { CvInpaintInvocation } from './models/CvInpaintInvocation';
|
||||||
@ -63,14 +69,6 @@ export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
|||||||
export type { InpaintInvocation } from './models/InpaintInvocation';
|
export type { InpaintInvocation } from './models/InpaintInvocation';
|
||||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||||
export type { IntOutput } from './models/IntOutput';
|
export type { IntOutput } from './models/IntOutput';
|
||||||
export type { invokeai__backend__model_management__models__controlnet__ControlNetModel__Config } from './models/invokeai__backend__model_management__models__controlnet__ControlNetModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__lora__LoRAModel__Config } from './models/invokeai__backend__model_management__models__lora__LoRAModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__CheckpointConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion1Model__DiffusersConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__CheckpointConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig } from './models/invokeai__backend__model_management__models__stable_diffusion__StableDiffusion2Model__DiffusersConfig';
|
|
||||||
export type { invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config } from './models/invokeai__backend__model_management__models__textual_inversion__TextualInversionModel__Config';
|
|
||||||
export type { invokeai__backend__model_management__models__vae__VaeModel__Config } from './models/invokeai__backend__model_management__models__vae__VaeModel__Config';
|
|
||||||
export type { IterateInvocation } from './models/IterateInvocation';
|
export type { IterateInvocation } from './models/IterateInvocation';
|
||||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||||
export type { LatentsField } from './models/LatentsField';
|
export type { LatentsField } from './models/LatentsField';
|
||||||
@ -83,6 +81,8 @@ export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
|||||||
export type { LoraInfo } from './models/LoraInfo';
|
export type { LoraInfo } from './models/LoraInfo';
|
||||||
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
export type { LoraLoaderInvocation } from './models/LoraLoaderInvocation';
|
||||||
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
export type { LoraLoaderOutput } from './models/LoraLoaderOutput';
|
||||||
|
export type { LoRAModelConfig } from './models/LoRAModelConfig';
|
||||||
|
export type { LoRAModelFormat } from './models/LoRAModelFormat';
|
||||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||||
export type { MaskOutput } from './models/MaskOutput';
|
export type { MaskOutput } from './models/MaskOutput';
|
||||||
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
export type { MediapipeFaceProcessorInvocation } from './models/MediapipeFaceProcessorInvocation';
|
||||||
@ -98,12 +98,15 @@ export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
|||||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||||
export type { NoiseOutput } from './models/NoiseOutput';
|
export type { NoiseOutput } from './models/NoiseOutput';
|
||||||
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
|
export type { NormalbaeImageProcessorInvocation } from './models/NormalbaeImageProcessorInvocation';
|
||||||
|
export type { OffsetPaginatedResults_BoardDTO_ } from './models/OffsetPaginatedResults_BoardDTO_';
|
||||||
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
|
export type { OffsetPaginatedResults_ImageDTO_ } from './models/OffsetPaginatedResults_ImageDTO_';
|
||||||
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
|
export type { OpenposeImageProcessorInvocation } from './models/OpenposeImageProcessorInvocation';
|
||||||
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
||||||
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
export type { ParamFloatInvocation } from './models/ParamFloatInvocation';
|
||||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||||
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
export type { PidiImageProcessorInvocation } from './models/PidiImageProcessorInvocation';
|
||||||
|
export type { PipelineModelField } from './models/PipelineModelField';
|
||||||
|
export type { PipelineModelLoaderInvocation } from './models/PipelineModelLoaderInvocation';
|
||||||
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
export type { PromptCollectionOutput } from './models/PromptCollectionOutput';
|
||||||
export type { PromptOutput } from './models/PromptOutput';
|
export type { PromptOutput } from './models/PromptOutput';
|
||||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||||
@ -115,20 +118,28 @@ export type { ResourceOrigin } from './models/ResourceOrigin';
|
|||||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||||
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
export type { SchedulerPredictionType } from './models/SchedulerPredictionType';
|
||||||
export type { SD1ModelLoaderInvocation } from './models/SD1ModelLoaderInvocation';
|
|
||||||
export type { SD2ModelLoaderInvocation } from './models/SD2ModelLoaderInvocation';
|
|
||||||
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
export type { ShowImageInvocation } from './models/ShowImageInvocation';
|
||||||
|
export type { StableDiffusion1ModelCheckpointConfig } from './models/StableDiffusion1ModelCheckpointConfig';
|
||||||
|
export type { StableDiffusion1ModelDiffusersConfig } from './models/StableDiffusion1ModelDiffusersConfig';
|
||||||
|
export type { StableDiffusion1ModelFormat } from './models/StableDiffusion1ModelFormat';
|
||||||
|
export type { StableDiffusion2ModelCheckpointConfig } from './models/StableDiffusion2ModelCheckpointConfig';
|
||||||
|
export type { StableDiffusion2ModelDiffusersConfig } from './models/StableDiffusion2ModelDiffusersConfig';
|
||||||
|
export type { StableDiffusion2ModelFormat } from './models/StableDiffusion2ModelFormat';
|
||||||
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
export type { StepParamEasingInvocation } from './models/StepParamEasingInvocation';
|
||||||
export type { SubModelType } from './models/SubModelType';
|
export type { SubModelType } from './models/SubModelType';
|
||||||
export type { SubtractInvocation } from './models/SubtractInvocation';
|
export type { SubtractInvocation } from './models/SubtractInvocation';
|
||||||
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
export type { TextToLatentsInvocation } from './models/TextToLatentsInvocation';
|
||||||
|
export type { TextualInversionModelConfig } from './models/TextualInversionModelConfig';
|
||||||
export type { UNetField } from './models/UNetField';
|
export type { UNetField } from './models/UNetField';
|
||||||
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
||||||
export type { VaeField } from './models/VaeField';
|
export type { VaeField } from './models/VaeField';
|
||||||
|
export type { VaeModelConfig } from './models/VaeModelConfig';
|
||||||
|
export type { VaeModelFormat } from './models/VaeModelFormat';
|
||||||
export type { VaeRepo } from './models/VaeRepo';
|
export type { VaeRepo } from './models/VaeRepo';
|
||||||
export type { ValidationError } from './models/ValidationError';
|
export type { ValidationError } from './models/ValidationError';
|
||||||
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
export type { ZoeDepthImageProcessorInvocation } from './models/ZoeDepthImageProcessorInvocation';
|
||||||
|
|
||||||
|
export { BoardsService } from './services/BoardsService';
|
||||||
export { ImagesService } from './services/ImagesService';
|
export { ImagesService } from './services/ImagesService';
|
||||||
export { ModelsService } from './services/ModelsService';
|
export { ModelsService } from './services/ModelsService';
|
||||||
export { SessionsService } from './services/SessionsService';
|
export { SessionsService } from './services/SessionsService';
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
export type BoardChanges = {
|
||||||
|
/**
|
||||||
|
* The board's new name.
|
||||||
|
*/
|
||||||
|
board_name?: string;
|
||||||
|
/**
|
||||||
|
* The name of the board's new cover image.
|
||||||
|
*/
|
||||||
|
cover_image_name?: string;
|
||||||
|
};
|
||||||
|
|
38
invokeai/frontend/web/src/services/api/models/BoardDTO.ts
Normal file
38
invokeai/frontend/web/src/services/api/models/BoardDTO.ts
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deserialized board record with cover image URL and image count.
|
||||||
|
*/
|
||||||
|
export type BoardDTO = {
|
||||||
|
/**
|
||||||
|
* The unique ID of the board.
|
||||||
|
*/
|
||||||
|
board_id: string;
|
||||||
|
/**
|
||||||
|
* The name of the board.
|
||||||
|
*/
|
||||||
|
board_name: string;
|
||||||
|
/**
|
||||||
|
* The created timestamp of the board.
|
||||||
|
*/
|
||||||
|
created_at: string;
|
||||||
|
/**
|
||||||
|
* The updated timestamp of the board.
|
||||||
|
*/
|
||||||
|
updated_at: string;
|
||||||
|
/**
|
||||||
|
* The deleted timestamp of the board.
|
||||||
|
*/
|
||||||
|
deleted_at?: string;
|
||||||
|
/**
|
||||||
|
* The name of the board's cover image.
|
||||||
|
*/
|
||||||
|
cover_image_name?: string;
|
||||||
|
/**
|
||||||
|
* The number of images in the board.
|
||||||
|
*/
|
||||||
|
image_count: number;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,15 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
export type Body_create_board_image = {
|
||||||
|
/**
|
||||||
|
* The id of the board to add to
|
||||||
|
*/
|
||||||
|
board_id: string;
|
||||||
|
/**
|
||||||
|
* The name of the image to add
|
||||||
|
*/
|
||||||
|
image_name: string;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,15 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
export type Body_remove_board_image = {
|
||||||
|
/**
|
||||||
|
* The id of the board
|
||||||
|
*/
|
||||||
|
board_id: string;
|
||||||
|
/**
|
||||||
|
* The name of the image to remove
|
||||||
|
*/
|
||||||
|
image_name: string;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,18 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { BaseModelType } from './BaseModelType';
|
||||||
|
import type { ControlNetModelFormat } from './ControlNetModelFormat';
|
||||||
|
import type { ModelError } from './ModelError';
|
||||||
|
|
||||||
|
export type ControlNetModelConfig = {
|
||||||
|
name: string;
|
||||||
|
base_model: BaseModelType;
|
||||||
|
type: 'controlnet';
|
||||||
|
path: string;
|
||||||
|
description?: string;
|
||||||
|
model_format: ControlNetModelFormat;
|
||||||
|
error?: ModelError;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An enumeration.
|
||||||
|
*/
|
||||||
|
export type ControlNetModelFormat = 'checkpoint' | 'diffusers';
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user