mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix/vae_conversion
This commit is contained in:
commit
9140e2c0f2
@ -25,7 +25,7 @@ done
|
||||
|
||||
if [ -z "$PYTHON" ]; then
|
||||
echo "A suitable Python interpreter could not be found"
|
||||
echo "Please install Python 3.9 or higher before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
|
||||
read -p "Press any key to exit"
|
||||
exit -1
|
||||
fi
|
||||
|
@ -2,8 +2,17 @@
|
||||
|
||||
from logging import Logger
|
||||
import os
|
||||
from invokeai.app.services.board_image_record_storage import (
|
||||
SqliteBoardImageRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.board_images import (
|
||||
BoardImagesService,
|
||||
BoardImagesServiceDependencies,
|
||||
)
|
||||
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
@ -57,7 +66,7 @@ class ApiDependencies:
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = config.db_path
|
||||
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
@ -72,7 +81,32 @@ class ApiDependencies:
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
|
||||
boards = BoardService(
|
||||
services=BoardServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
board_images = BoardImagesService(
|
||||
services=BoardImagesServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
board_record_storage=board_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
url=urls,
|
||||
logger=logger,
|
||||
)
|
||||
)
|
||||
|
||||
images = ImageService(
|
||||
services=ImageServiceDependencies(
|
||||
board_image_record_storage=board_image_record_storage,
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
@ -81,12 +115,15 @@ class ApiDependencies:
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
)
|
||||
|
||||
services = InvocationServices(
|
||||
model_manager=ModelManagerService(config,logger),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
|
69
invokeai/app/api/routers/board_images.py
Normal file
69
invokeai/app/api/routers/board_images.py
Normal file
@ -0,0 +1,69 @@
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||
|
||||
|
||||
@board_images_router.post(
|
||||
"/",
|
||||
operation_id="create_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was added to a board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def create_board_image(
|
||||
board_id: str = Body(description="The id of the board to add to"),
|
||||
image_name: str = Body(description="The name of the image to add"),
|
||||
):
|
||||
"""Creates a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
||||
|
||||
@board_images_router.delete(
|
||||
"/",
|
||||
operation_id="remove_board_image",
|
||||
responses={
|
||||
201: {"description": "The image was removed from the board successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def remove_board_image(
|
||||
board_id: str = Body(description="The id of the board"),
|
||||
image_name: str = Body(description="The name of the image to remove"),
|
||||
):
|
||||
"""Deletes a board_image"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(board_id=board_id, image_name=image_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
|
||||
@board_images_router.get(
|
||||
"/{board_id}",
|
||||
operation_id="list_board_images",
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_board_images(
|
||||
board_id: str = Path(description="The id of the board"),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of boards per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images for a board"""
|
||||
|
||||
results = ApiDependencies.invoker.services.board_images.get_images_for_board(
|
||||
board_id,
|
||||
)
|
||||
return results
|
||||
|
108
invokeai/app/api/routers/boards.py
Normal file
108
invokeai/app/api/routers/boards.py
Normal file
@ -0,0 +1,108 @@
|
||||
from typing import Optional, Union
|
||||
from fastapi import Body, HTTPException, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from invokeai.app.services.board_record_storage import BoardChanges
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
"/",
|
||||
operation_id="create_board",
|
||||
responses={
|
||||
201: {"description": "The board was created successfully"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def create_board(
|
||||
board_name: str = Query(description="The name of the board to create"),
|
||||
) -> BoardDTO:
|
||||
"""Creates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.create(board_name=board_name)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to create board")
|
||||
|
||||
|
||||
@boards_router.get("/{board_id}", operation_id="get_board", response_model=BoardDTO)
|
||||
async def get_board(
|
||||
board_id: str = Path(description="The id of board to get"),
|
||||
) -> BoardDTO:
|
||||
"""Gets a board"""
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404, detail="Board not found")
|
||||
|
||||
|
||||
@boards_router.patch(
|
||||
"/{board_id}",
|
||||
operation_id="update_board",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The board was updated successfully",
|
||||
},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=BoardDTO,
|
||||
)
|
||||
async def update_board(
|
||||
board_id: str = Path(description="The id of board to update"),
|
||||
changes: BoardChanges = Body(description="The changes to apply to the board"),
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board")
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
) -> None:
|
||||
"""Deletes a board"""
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
except Exception as e:
|
||||
# TODO: Does this need any exception handling at all?
|
||||
pass
|
||||
|
||||
|
||||
@boards_router.get(
|
||||
"/",
|
||||
operation_id="list_boards",
|
||||
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
|
||||
)
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
return ApiDependencies.invoker.services.boards.get_all()
|
||||
elif offset is not None and limit is not None:
|
||||
return ApiDependencies.invoker.services.boards.get_many(
|
||||
offset,
|
||||
limit,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid request: Must provide either 'all' or both 'offset' and 'limit'",
|
||||
)
|
@ -221,6 +221,9 @@ async def list_images_with_metadata(
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board id to filter by"
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
@ -232,6 +235,7 @@ async def list_images_with_metadata(
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
@ -7,8 +7,8 @@ from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import get_all_model_configs
|
||||
MODEL_CONFIGS = Union[tuple(get_all_model_configs())]
|
||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
|
||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||
models: list[MODEL_CONFIGS]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models(
|
||||
base_model: BaseModelType = Query(
|
||||
base_model: Optional[BaseModelType] = Query(
|
||||
default=None, description="Base model"
|
||||
),
|
||||
model_type: ModelType = Query(
|
||||
model_type: Optional[ModelType] = Query(
|
||||
default=None, description="The type of model to get"
|
||||
),
|
||||
) -> ModelsList:
|
||||
|
@ -24,7 +24,7 @@ logger = InvokeAILogger.getLogger(config=app_config)
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import sessions, models, images
|
||||
from .api.routers import sessions, models, images, boards, board_images
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
|
||||
@ -78,6 +78,10 @@ app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
|
||||
app.include_router(board_images.board_images_router, prefix="/api")
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi():
|
||||
@ -116,6 +120,22 @@ def custom_openapi():
|
||||
|
||||
invoker_schema["output"] = outputs_ref
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
# "BaseModelType":{"title":"BaseModelType","description":"An enumeration.","enum":["sd-1","sd-2"],"type":"string"}
|
||||
openapi_schema["components"]["schemas"][name] = dict(
|
||||
title=name,
|
||||
description="An enumeration.",
|
||||
type="string",
|
||||
enum=list(v.value for v in model_config_format_enum),
|
||||
)
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
@ -12,12 +12,19 @@ from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||
from invokeai.backend.generator.inpaint import infill_methods
|
||||
|
||||
from ...backend.generator import Img2Img, Inpaint, InvokeAIGenerator, Txt2Img
|
||||
from ...backend.generator import Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.step_callback import stable_diffusion_step_callback
|
||||
from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
|
||||
from .image import ImageOutput
|
||||
|
||||
import re
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from .model import UNetField, VaeField
|
||||
from .compel import ConditioningField
|
||||
from contextlib import contextmanager, ExitStack, ContextDecorator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
DEFAULT_INFILL_METHOD = (
|
||||
@ -25,114 +32,48 @@ DEFAULT_INFILL_METHOD = (
|
||||
)
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
from .latent import get_scheduler
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
class OldModelContext(ContextDecorator):
|
||||
model: StableDiffusionGeneratorPipeline
|
||||
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def __enter__(self):
|
||||
return self.model
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
class OldModelInfo:
|
||||
name: str
|
||||
hash: str
|
||||
context: OldModelContext
|
||||
|
||||
def __init__(self, name: str, hash: str, model: StableDiffusionGeneratorPipeline):
|
||||
self.name = name
|
||||
self.hash = hash
|
||||
self.context = OldModelContext(
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
class InpaintInvocation(BaseInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
control_model: Optional[str] = Field(default=None, description="The control model to use")
|
||||
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
|
||||
# fmt: on
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# loading controlnet image (currently requires pre-processed image)
|
||||
control_image = (
|
||||
None if self.control_image is None
|
||||
else context.services.images.get_pil_image(self.control_image.image_name)
|
||||
)
|
||||
# loading controlnet model
|
||||
if (self.control_model is None or self.control_model==''):
|
||||
control_model = None
|
||||
else:
|
||||
# FIXME: change this to dropdown menu?
|
||||
# FIXME: generalize so don't have to hardcode torch_dtype and device
|
||||
control_model = ControlNetModel.from_pretrained(self.control_model,
|
||||
torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
txt2img = Txt2Img(model, control_model=control_model)
|
||||
outputs = txt2img.generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
control_image=control_image,
|
||||
**self.dict(
|
||||
exclude={"prompt", "control_image" }
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generate_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class ImageToImageInvocation(TextToImageInvocation):
|
||||
"""Generates an image using img2img."""
|
||||
|
||||
type: Literal["img2img"] = "img2img"
|
||||
unet: UNetField = Field(default=None, description="UNet model")
|
||||
vae: VaeField = Field(default=None, description="Vae model")
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, None] = Field(description="The input image")
|
||||
@ -144,72 +85,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||
)
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.dict(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
if self.image is None
|
||||
else context.services.images.get_pil_image(self.image.image_name)
|
||||
)
|
||||
|
||||
if self.fit:
|
||||
image = image.resize((self.width, self.height))
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generator_output = next(outputs)
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=generator_output.image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
type: Literal["inpaint"] = "inpaint"
|
||||
|
||||
# Inputs
|
||||
mask: Union[ImageField, None] = Field(description="The mask")
|
||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||
@ -252,6 +127,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
description="The amount by which to replace masked areas with latent noise",
|
||||
)
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@ -265,6 +148,49 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
def get_conditioning(self, context):
|
||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
||||
return (uc, c, extra_conditioning_info)
|
||||
|
||||
@contextmanager
|
||||
def load_model_old_way(self, context, scheduler):
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
||||
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||
|
||||
#unet = unet_info.context.model
|
||||
#vae = vae_info.context.model
|
||||
|
||||
with ExitStack() as stack:
|
||||
loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
|
||||
with vae_info as vae,\
|
||||
unet_info as unet,\
|
||||
ModelPatcher.apply_lora_unet(unet, loras):
|
||||
|
||||
device = context.services.model_manager.mgr.cache.execution_device
|
||||
dtype = context.services.model_manager.mgr.cache.precision
|
||||
|
||||
pipeline = StableDiffusionGeneratorPipeline(
|
||||
vae=vae,
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
precision="float16" if dtype == torch.float16 else "float32",
|
||||
execution_device=device,
|
||||
)
|
||||
|
||||
yield OldModelInfo(
|
||||
name=self.unet.unet.model_name,
|
||||
hash="<NO-HASH>",
|
||||
model=pipeline,
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = (
|
||||
None
|
||||
@ -277,22 +203,28 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
else context.services.images.get_pil_image(self.mask.image_name)
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
model = context.services.model_manager.get_model(self.model,node=self,context=context)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(
|
||||
context.graph_execution_state_id
|
||||
)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
conditioning = self.get_conditioning(context)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
|
||||
with self.load_model_old_way(context, scheduler) as model:
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
conditioning=conditioning,
|
||||
scheduler=scheduler,
|
||||
init_image=image,
|
||||
mask_image=mask,
|
||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
exclude={"positive_conditioning", "negative_conditioning", "scheduler", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ import einops
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import torch
|
||||
from diffusers import ControlNetModel
|
||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
@ -233,7 +233,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(scheduler, eta=0.0)#ddim_eta)
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
scheduler,
|
||||
|
||||
# for ddim scheduler
|
||||
eta=0.0, #ddim_eta
|
||||
|
||||
# for ancestral and sde schedulers
|
||||
generator=torch.Generator(device=uc.device).manual_seed(0),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
|
||||
|
@ -43,12 +43,19 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
class PipelineModelField(BaseModel):
|
||||
"""Pipeline model field"""
|
||||
|
||||
type: Literal["sd1_model_loader"] = "sd1_model_loader"
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
|
||||
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a pipeline model, outputting its submodels."""
|
||||
|
||||
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||
|
||||
model: PipelineModelField = Field(description="The model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
@ -57,22 +64,24 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model_name": "model" # TODO: rename to model_name?
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion1 # TODO:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Pipeline
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
@ -107,142 +116,39 @@ class SD1ModelLoaderInvocation(BaseInvocation):
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: optimize(less code copy)
|
||||
class SD2ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loading submodels of selected model."""
|
||||
|
||||
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
||||
|
||||
model_name: str = Field(default="", description="Model to load")
|
||||
# TODO: precision?
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["model", "loader"],
|
||||
"type_hints": {
|
||||
"model_name": "model" # TODO: rename to model_name?
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
|
||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
):
|
||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=self.model_name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Pipeline,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
),
|
||||
)
|
||||
|
254
invokeai/app/services/board_image_record_storage.py
Normal file
254
invokeai/app/services/board_image_record_storage.py
Normal file
@ -0,0 +1,254 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
from invokeai.app.services.board_record_storage import BoardRecord
|
||||
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
|
||||
|
||||
class BoardImageRecordStorageBase(ABC):
|
||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_count_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> int:
|
||||
"""Gets the number of images for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `board_images` junction table."""
|
||||
|
||||
# Create the `board_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS board_images (
|
||||
board_id TEXT NOT NULL,
|
||||
image_name TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
-- enforce one-to-many relationship between boards and images using PK
|
||||
-- (we can extend this to many-to-many later)
|
||||
PRIMARY KEY (image_name),
|
||||
FOREIGN KEY (board_id) REFERENCES boards (board_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id ON board_images (board_id);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add index for board id, sorted by created_at
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_board_images_board_id_created_at ON board_images (board_id, created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_board_images_updated_at
|
||||
AFTER UPDATE
|
||||
ON board_images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE board_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE board_id = old.board_id AND image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT (image_name) DO UPDATE SET board_id = ?;
|
||||
""",
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE board_id = ? AND image_name = ?;
|
||||
""",
|
||||
(board_id, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM board_images WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
142
invokeai/app/services/board_images.py
Normal file
142
invokeai/app/services/board_images.py
Normal file
@ -0,0 +1,142 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from typing import List, Union
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardRecord,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.models.image_record import ImageDTO, image_record_to_dto
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardImagesServiceABC(ABC):
|
||||
"""High-level service for board-image relationship management."""
|
||||
|
||||
@abstractmethod
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Adds an image to a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
"""Removes an image from a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
"""Gets an image's board id, if it has one."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardImagesServiceDependencies:
|
||||
"""Service dependencies for the BoardImagesService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardImagesService(BoardImagesServiceABC):
|
||||
_services: BoardImagesServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardImagesServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
image_records = self._services.board_image_records.get_images_for_board(
|
||||
board_id
|
||||
)
|
||||
image_dtos = list(
|
||||
map(
|
||||
lambda r: image_record_to_dto(
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
board_id,
|
||||
),
|
||||
image_records.items,
|
||||
)
|
||||
)
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=image_records.offset,
|
||||
limit=image_records.limit,
|
||||
total=image_records.total,
|
||||
)
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Union[str, None]:
|
||||
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||
return board_id
|
||||
|
||||
|
||||
def board_record_to_dto(
|
||||
board_record: BoardRecord, cover_image_name: str | None, image_count: int
|
||||
) -> BoardDTO:
|
||||
"""Converts a board record to a board DTO."""
|
||||
return BoardDTO(
|
||||
**board_record.dict(exclude={'cover_image_name'}),
|
||||
cover_image_name=cover_image_name,
|
||||
image_count=image_count,
|
||||
)
|
329
invokeai/app/services/board_record_storage.py
Normal file
329
invokeai/app/services/board_record_storage.py
Normal file
@ -0,0 +1,329 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
import uuid
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.board_record import (
|
||||
BoardRecord,
|
||||
deserialize_board_record,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel, Field, Extra
|
||||
|
||||
|
||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||
board_name: Optional[str] = Field(description="The board's new name.")
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's new cover image."
|
||||
)
|
||||
|
||||
|
||||
class BoardRecordNotFoundException(Exception):
|
||||
"""Raised when an board record is not found."""
|
||||
|
||||
def __init__(self, message="Board record not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordSaveException(Exception):
|
||||
"""Raised when an board record cannot be saved."""
|
||||
|
||||
def __init__(self, message="Board record not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordDeleteException(Exception):
|
||||
"""Raised when an board record cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Board record not deleted"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class BoardRecordStorageBase(ABC):
|
||||
"""Low-level service responsible for interfacing with the board record store."""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, board_id: str) -> None:
|
||||
"""Deletes a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
"""Saves a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
"""Gets a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
"""Updates a board record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
"""Gets many board records."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
"""Gets all board records."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_filename: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = threading.Lock()
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the `boards` table and `board_images` junction table."""
|
||||
|
||||
# Create the `boards` table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS boards (
|
||||
board_id TEXT NOT NULL PRIMARY KEY,
|
||||
board_name TEXT NOT NULL,
|
||||
cover_image_name TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Soft delete, currently unused
|
||||
deleted_at DATETIME,
|
||||
FOREIGN KEY (cover_image_name) REFERENCES images (image_name) ON DELETE SET NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_boards_created_at ON boards (created_at);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_boards_updated_at
|
||||
AFTER UPDATE
|
||||
ON boards FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE boards SET updated_at = current_timestamp
|
||||
WHERE board_id = old.board_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = str(uuid.uuid4())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(board_id, board_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.board_name, board_id),
|
||||
)
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
WHERE board_id = ?;
|
||||
""",
|
||||
(changes.cover_image_name, board_id),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?;
|
||||
""",
|
||||
(limit, offset),
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
# Get the total number of boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](
|
||||
items=boards, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Get all the boards
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
ORDER BY created_at DESC
|
||||
"""
|
||||
)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = list(map(lambda r: deserialize_board_record(dict(r)), result))
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
185
invokeai/app/services/boards.py
Normal file
185
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,185 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from logging import Logger
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.board_images import board_record_to_dto
|
||||
|
||||
from invokeai.app.services.board_record_storage import (
|
||||
BoardChanges,
|
||||
BoardRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.board_record import BoardDTO
|
||||
from invokeai.app.services.urls import UrlServiceBase
|
||||
|
||||
|
||||
class BoardServiceABC(ABC):
|
||||
"""High-level service for board management."""
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
"""Creates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_dto(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> BoardDTO:
|
||||
"""Gets a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
"""Updates a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
board_id: str,
|
||||
) -> None:
|
||||
"""Deletes a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
"""Gets many boards."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(
|
||||
self,
|
||||
) -> list[BoardDTO]:
|
||||
"""Gets all boards."""
|
||||
pass
|
||||
|
||||
|
||||
class BoardServiceDependencies:
|
||||
"""Service dependencies for the BoardService."""
|
||||
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
board_records: BoardRecordStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
board_record_storage: BoardRecordStorageBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
):
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.image_records = image_record_storage
|
||||
self.board_records = board_record_storage
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
|
||||
|
||||
class BoardService(BoardServiceABC):
|
||||
_services: BoardServiceDependencies
|
||||
|
||||
def __init__(self, services: BoardServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
board_name: str,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.save(board_name)
|
||||
return board_record_to_dto(board_record, None, 0)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
board_record = self._services.board_records.get(board_id)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def update(
|
||||
self,
|
||||
board_id: str,
|
||||
changes: BoardChanges,
|
||||
) -> BoardDTO:
|
||||
board_record = self._services.board_records.update(board_id, changes)
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
board_record.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
board_id
|
||||
)
|
||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
self._services.board_records.delete(board_id)
|
||||
|
||||
def get_many(
|
||||
self, offset: int = 0, limit: int = 10
|
||||
) -> OffsetPaginatedResults[BoardDTO]:
|
||||
board_records = self._services.board_records.get_many(offset, limit)
|
||||
board_dtos = []
|
||||
for r in board_records.items:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return OffsetPaginatedResults[BoardDTO](
|
||||
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||
)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
board_records = self._services.board_records.get_all()
|
||||
board_dtos = []
|
||||
for r in board_records:
|
||||
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||
r.board_id
|
||||
)
|
||||
if cover_image:
|
||||
cover_image_name = cover_image.image_name
|
||||
else:
|
||||
cover_image_name = None
|
||||
|
||||
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||
r.board_id
|
||||
)
|
||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||
|
||||
return board_dtos
|
@ -82,6 +82,7 @@ class ImageRecordStorageBase(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
@ -109,6 +110,11 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Saves an image record."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> ImageRecord | None:
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_filename: str
|
||||
@ -135,7 +141,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
"""Creates the tables for the `images` database."""
|
||||
"""Creates the `images` table."""
|
||||
|
||||
# Create the `images` table.
|
||||
self._cursor.execute(
|
||||
@ -152,6 +158,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
node_id TEXT,
|
||||
metadata TEXT,
|
||||
is_intermediate BOOLEAN DEFAULT FALSE,
|
||||
board_id TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
@ -190,7 +197,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
AFTER UPDATE
|
||||
ON images FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE images SET updated_at = current_timestamp
|
||||
UPDATE images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE image_name = old.image_name;
|
||||
END;
|
||||
"""
|
||||
@ -259,6 +266,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -273,38 +281,66 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
|
||||
images_query = f"""SELECT * FROM images WHERE 1=1\n"""
|
||||
images_query = """--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = list(map(lambda c: c.value, set(categories)))
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
|
||||
if board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
|
||||
query_params.append(board_id)
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
@ -321,7 +357,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = self._cursor.fetchone()[0]
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
@ -412,3 +448,28 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(
|
||||
self, board_id: str
|
||||
) -> Union[ImageRecord, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
@ -10,6 +10,7 @@ from invokeai.app.models.image import (
|
||||
InvalidOriginException,
|
||||
)
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
@ -49,7 +50,7 @@ class ImageServiceABC(ABC):
|
||||
image_category: ImageCategory,
|
||||
node_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
intermediate: bool = False,
|
||||
is_intermediate: bool = False,
|
||||
) -> ImageDTO:
|
||||
"""Creates an image, storing the file and its metadata."""
|
||||
pass
|
||||
@ -79,7 +80,7 @@ class ImageServiceABC(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str) -> str:
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets an image's path."""
|
||||
pass
|
||||
|
||||
@ -101,6 +102,7 @@ class ImageServiceABC(ABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
@ -114,8 +116,9 @@ class ImageServiceABC(ABC):
|
||||
class ImageServiceDependencies:
|
||||
"""Service dependencies for the ImageService."""
|
||||
|
||||
records: ImageRecordStorageBase
|
||||
files: ImageFileStorageBase
|
||||
image_records: ImageRecordStorageBase
|
||||
image_files: ImageFileStorageBase
|
||||
board_image_records: BoardImageRecordStorageBase
|
||||
metadata: MetadataServiceBase
|
||||
urls: UrlServiceBase
|
||||
logger: Logger
|
||||
@ -126,14 +129,16 @@ class ImageServiceDependencies:
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
board_image_record_storage: BoardImageRecordStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self.records = image_record_storage
|
||||
self.files = image_file_storage
|
||||
self.image_records = image_record_storage
|
||||
self.image_files = image_file_storage
|
||||
self.board_image_records = board_image_record_storage
|
||||
self.metadata = metadata
|
||||
self.urls = url
|
||||
self.logger = logger
|
||||
@ -144,25 +149,8 @@ class ImageServiceDependencies:
|
||||
class ImageService(ImageServiceABC):
|
||||
_services: ImageServiceDependencies
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_record_storage: ImageRecordStorageBase,
|
||||
image_file_storage: ImageFileStorageBase,
|
||||
metadata: MetadataServiceBase,
|
||||
url: UrlServiceBase,
|
||||
logger: Logger,
|
||||
names: NameServiceBase,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
):
|
||||
self._services = ImageServiceDependencies(
|
||||
image_record_storage=image_record_storage,
|
||||
image_file_storage=image_file_storage,
|
||||
metadata=metadata,
|
||||
url=url,
|
||||
logger=logger,
|
||||
names=names,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
)
|
||||
def __init__(self, services: ImageServiceDependencies):
|
||||
self._services = services
|
||||
|
||||
def create(
|
||||
self,
|
||||
@ -187,7 +175,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
try:
|
||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||
created_at = self._services.records.save(
|
||||
self._services.image_records.save(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
@ -202,35 +190,15 @@ class ImageService(ImageServiceABC):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
self._services.files.save(
|
||||
self._services.image_files.save(
|
||||
image_name=image_name,
|
||||
image=image,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
image_url = self._services.urls.get_image_url(image_name)
|
||||
thumbnail_url = self._services.urls.get_image_url(image_name, True)
|
||||
image_dto = self.get_dto(image_name)
|
||||
|
||||
return ImageDTO(
|
||||
# Non-nullable fields
|
||||
image_name=image_name,
|
||||
image_origin=image_origin,
|
||||
image_category=image_category,
|
||||
width=width,
|
||||
height=height,
|
||||
# Nullable fields
|
||||
node_id=node_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata,
|
||||
# Meta fields
|
||||
created_at=created_at,
|
||||
updated_at=created_at, # this is always the same as the created_at at this time
|
||||
deleted_at=None,
|
||||
is_intermediate=is_intermediate,
|
||||
# Extra non-nullable fields for DTO
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
)
|
||||
return image_dto
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to save image record")
|
||||
raise
|
||||
@ -247,7 +215,7 @@ class ImageService(ImageServiceABC):
|
||||
changes: ImageRecordChanges,
|
||||
) -> ImageDTO:
|
||||
try:
|
||||
self._services.records.update(image_name, changes)
|
||||
self._services.image_records.update(image_name, changes)
|
||||
return self.get_dto(image_name)
|
||||
except ImageRecordSaveException:
|
||||
self._services.logger.error("Failed to update image record")
|
||||
@ -258,7 +226,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||
try:
|
||||
return self._services.files.get(image_name)
|
||||
return self._services.image_files.get(image_name)
|
||||
except ImageFileNotFoundException:
|
||||
self._services.logger.error("Failed to get image file")
|
||||
raise
|
||||
@ -268,7 +236,7 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_record(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
return self._services.records.get(image_name)
|
||||
return self._services.image_records.get(image_name)
|
||||
except ImageRecordNotFoundException:
|
||||
self._services.logger.error("Image record not found")
|
||||
raise
|
||||
@ -278,12 +246,13 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
try:
|
||||
image_record = self._services.records.get(image_name)
|
||||
image_record = self._services.image_records.get(image_name)
|
||||
|
||||
image_dto = image_record_to_dto(
|
||||
image_record,
|
||||
self._services.urls.get_image_url(image_name),
|
||||
self._services.urls.get_image_url(image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(image_name),
|
||||
)
|
||||
|
||||
return image_dto
|
||||
@ -296,14 +265,14 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
try:
|
||||
return self._services.files.get_path(image_name, thumbnail)
|
||||
return self._services.image_files.get_path(image_name, thumbnail)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem getting image path")
|
||||
raise e
|
||||
|
||||
def validate_path(self, path: str) -> bool:
|
||||
try:
|
||||
return self._services.files.validate_path(path)
|
||||
return self._services.image_files.validate_path(path)
|
||||
except Exception as e:
|
||||
self._services.logger.error("Problem validating image path")
|
||||
raise e
|
||||
@ -322,14 +291,16 @@ class ImageService(ImageServiceABC):
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
results = self._services.image_records.get_many(
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
)
|
||||
|
||||
image_dtos = list(
|
||||
@ -338,6 +309,9 @@ class ImageService(ImageServiceABC):
|
||||
r,
|
||||
self._services.urls.get_image_url(r.image_name),
|
||||
self._services.urls.get_image_url(r.image_name, True),
|
||||
self._services.board_image_records.get_board_for_image(
|
||||
r.image_name
|
||||
),
|
||||
),
|
||||
results.items,
|
||||
)
|
||||
@ -355,8 +329,8 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def delete(self, image_name: str):
|
||||
try:
|
||||
self._services.files.delete(image_name)
|
||||
self._services.records.delete(image_name)
|
||||
self._services.image_files.delete(image_name)
|
||||
self._services.image_records.delete(image_name)
|
||||
except ImageRecordDeleteException:
|
||||
self._services.logger.error(f"Failed to delete image record")
|
||||
raise
|
||||
|
@ -4,7 +4,9 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||
from invokeai.app.services.boards import BoardServiceABC
|
||||
from invokeai.app.services.images import ImageServiceABC
|
||||
from invokeai.backend import ModelManager
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
@ -26,9 +28,9 @@ class InvocationServices:
|
||||
model_manager: "ModelManager"
|
||||
restoration: "RestorationServices"
|
||||
configuration: "InvokeAISettings"
|
||||
images: "ImageService"
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
images: "ImageServiceABC"
|
||||
boards: "BoardServiceABC"
|
||||
board_images: "BoardImagesServiceABC"
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
@ -39,7 +41,9 @@ class InvocationServices:
|
||||
events: "EventServiceBase",
|
||||
logger: "Logger",
|
||||
latents: "LatentsStorageBase",
|
||||
images: "ImageService",
|
||||
images: "ImageServiceABC",
|
||||
boards: "BoardServiceABC",
|
||||
board_images: "BoardImagesServiceABC",
|
||||
queue: "InvocationQueueABC",
|
||||
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||
@ -52,9 +56,12 @@ class InvocationServices:
|
||||
self.logger = logger
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.boards = boards
|
||||
self.board_images = board_images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
self.configuration = configuration
|
||||
self.boards = boards
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||
from dataclasses import dataclass
|
||||
|
||||
from invokeai.backend.model_management.model_manager import (
|
||||
@ -69,19 +69,6 @@ class ModelManagerServiceBase(ABC):
|
||||
) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name and typeof the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
@ -270,17 +257,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_type,
|
||||
)
|
||||
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
return self.mgr.default_model()
|
||||
|
||||
def set_default_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType):
|
||||
"""Sets the default model to the indicated name."""
|
||||
self.mgr.set_default_model(model_name, base_model, model_type)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
@ -297,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None
|
||||
) -> dict:
|
||||
) -> list[dict]:
|
||||
# ) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
|
62
invokeai/app/services/models/board_record.py
Normal file
62
invokeai/app/services/models/board_record.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
|
||||
|
||||
class BoardRecord(BaseModel):
|
||||
"""Deserialized board record."""
|
||||
|
||||
board_id: str = Field(description="The unique ID of the board.")
|
||||
"""The unique ID of the board."""
|
||||
board_name: str = Field(description="The name of the board.")
|
||||
"""The name of the board."""
|
||||
created_at: Union[datetime, str] = Field(
|
||||
description="The created timestamp of the board."
|
||||
)
|
||||
"""The created timestamp of the image."""
|
||||
updated_at: Union[datetime, str] = Field(
|
||||
description="The updated timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
deleted_at: Union[datetime, str, None] = Field(
|
||||
description="The deleted timestamp of the board."
|
||||
)
|
||||
"""The updated timestamp of the image."""
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the cover image of the board."
|
||||
)
|
||||
"""The name of the cover image of the board."""
|
||||
|
||||
|
||||
class BoardDTO(BoardRecord):
|
||||
"""Deserialized board record with cover image URL and image count."""
|
||||
|
||||
cover_image_name: Optional[str] = Field(
|
||||
description="The name of the board's cover image."
|
||||
)
|
||||
"""The URL of the thumbnail of the most recent image in the board."""
|
||||
image_count: int = Field(description="The number of images in the board.")
|
||||
"""The number of images in the board."""
|
||||
|
||||
|
||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||
"""Deserializes a board record."""
|
||||
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
board_id = board_dict.get("board_id", "unknown")
|
||||
board_name = board_dict.get("board_name", "unknown")
|
||||
cover_image_name = board_dict.get("cover_image_name", "unknown")
|
||||
created_at = board_dict.get("created_at", get_iso_timestamp())
|
||||
updated_at = board_dict.get("updated_at", get_iso_timestamp())
|
||||
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
|
||||
|
||||
return BoardRecord(
|
||||
board_id=board_id,
|
||||
board_name=board_name,
|
||||
cover_image_name=cover_image_name,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
deleted_at=deleted_at,
|
||||
)
|
@ -86,19 +86,24 @@ class ImageUrlsDTO(BaseModel):
|
||||
|
||||
|
||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||
"""Deserialized image record, enriched for the frontend."""
|
||||
|
||||
board_id: Union[str, None] = Field(
|
||||
description="The id of the board the image belongs to, if one exists."
|
||||
)
|
||||
"""The id of the board the image belongs to, if one exists."""
|
||||
pass
|
||||
|
||||
|
||||
def image_record_to_dto(
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Union[str, None]
|
||||
) -> ImageDTO:
|
||||
"""Converts an image record to an image DTO."""
|
||||
return ImageDTO(
|
||||
**image_record.dict(),
|
||||
image_url=image_url,
|
||||
thumbnail_url=thumbnail_url,
|
||||
board_id=board_id,
|
||||
)
|
||||
|
||||
|
||||
|
@ -5,7 +5,6 @@ from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint
|
||||
)
|
||||
|
@ -5,7 +5,6 @@ from .base import (
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
Img2Img,
|
||||
Inpaint,
|
||||
Generator,
|
||||
|
@ -29,7 +29,6 @@ import invokeai.backend.util.logging as logger
|
||||
from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
@ -81,8 +80,10 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
self.params=params
|
||||
self.kwargs = kwargs
|
||||
|
||||
def generate(self,
|
||||
prompt: str='',
|
||||
def generate(
|
||||
self,
|
||||
conditioning: tuple,
|
||||
scheduler,
|
||||
callback: Optional[Callable]=None,
|
||||
step_callback: Optional[Callable]=None,
|
||||
iterations: int=1,
|
||||
@ -116,11 +117,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
model_name = model_info.name
|
||||
model_hash = model_info.hash
|
||||
with model_info.context as model:
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||
if self.params.variation_amount > 0:
|
||||
@ -143,8 +139,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
results = generator.generate(
|
||||
conditioning=conditioning,
|
||||
step_callback=step_callback,
|
||||
sampler=scheduler,
|
||||
**generator_args,
|
||||
@ -170,20 +166,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
|
||||
scheduler_config = model.scheduler.config
|
||||
if "_backup" in scheduler_config:
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
@ -193,13 +175,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .txt2img import Txt2Img
|
||||
return Txt2Img
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
def generate(self,
|
||||
@ -253,24 +228,6 @@ class Inpaint(Img2Img):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
# ------------------------------------
|
||||
class Embiggen(Txt2Img):
|
||||
def generate(
|
||||
self,
|
||||
embiggen: list=None,
|
||||
embiggen_tiles: list = None,
|
||||
strength: float=0.75,
|
||||
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||
return super().generate(embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .embiggen import Embiggen
|
||||
return Embiggen
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
@ -281,7 +238,7 @@ class Generator:
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.latent_channels = model.unet.config.in_channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
@ -292,7 +249,7 @@ class Generator:
|
||||
self.free_gpu_mem = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
def get_make_image(self, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
@ -308,7 +265,6 @@ class Generator:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
width,
|
||||
height,
|
||||
sampler,
|
||||
@ -333,7 +289,6 @@ class Generator:
|
||||
saver.get_stacked_maps_image()
|
||||
)
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler=sampler,
|
||||
init_image=init_image,
|
||||
width=width,
|
||||
|
@ -1,559 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.embiggen descends from .generator
|
||||
and generates with .generator.img2img
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from .base import Generator
|
||||
from .img2img import Img2Img
|
||||
|
||||
class Embiggen(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None
|
||||
|
||||
# Replace generate because Embiggen doesn't need/use most of what it does normallly
|
||||
def generate(
|
||||
self,
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=None,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
make_image = self.get_make_image(prompt, step_callback=step_callback, **kwargs)
|
||||
results = []
|
||||
seed = seed if seed else self.new_seed()
|
||||
|
||||
# Noise will be generated by the Img2Img generator when called
|
||||
for _ in trange(iterations, desc="Generating"):
|
||||
# make_image will call Img2Img which will do the equivalent of get_noise itself
|
||||
image = make_image()
|
||||
results.append([image, seed])
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, prompt_in=prompt)
|
||||
seed = self.new_seed()
|
||||
return results
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
init_img,
|
||||
strength,
|
||||
width,
|
||||
height,
|
||||
embiggen,
|
||||
embiggen_tiles,
|
||||
step_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
"""
|
||||
assert (
|
||||
not sampler.uses_inpainting_model()
|
||||
), "--embiggen is not supported by inpainting models"
|
||||
|
||||
# Construct embiggen arg array, and sanity check arguments
|
||||
if embiggen == None: # embiggen can also be called with just embiggen_tiles
|
||||
embiggen = [1.0] # If not specified, assume no scaling
|
||||
elif embiggen[0] < 0:
|
||||
embiggen[0] = 1.0
|
||||
logger.warning(
|
||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||
)
|
||||
if len(embiggen) < 2:
|
||||
embiggen.append(0.75)
|
||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||
embiggen[1] = 0.75
|
||||
logger.warning(
|
||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||
)
|
||||
if len(embiggen) < 3:
|
||||
embiggen.append(0.25)
|
||||
elif embiggen[2] < 0:
|
||||
embiggen[2] = 0.25
|
||||
logger.warning(
|
||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||
)
|
||||
|
||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||
# and then sort them, because... people.
|
||||
if embiggen_tiles:
|
||||
embiggen_tiles = list(map(lambda n: n - 1, embiggen_tiles))
|
||||
embiggen_tiles.sort()
|
||||
|
||||
if strength >= 0.5:
|
||||
logger.warning(
|
||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||
)
|
||||
|
||||
# Prep img2img generator, since we wrap over it
|
||||
gen_img2img = Img2Img(self.model, self.precision)
|
||||
|
||||
# Open original init image (not a tensor) to manipulate
|
||||
initsuperimage = Image.open(init_img)
|
||||
|
||||
with Image.open(init_img) as img:
|
||||
initsuperimage = img.convert("RGB")
|
||||
|
||||
# Size of the target super init image in pixels
|
||||
initsuperwidth, initsuperheight = initsuperimage.size
|
||||
|
||||
# Increase by scaling factor if not already resized, using ESRGAN as able
|
||||
if embiggen[0] != 1.0:
|
||||
initsuperwidth = round(initsuperwidth * embiggen[0])
|
||||
initsuperheight = round(initsuperheight * embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ..restoration.realesrgan import ESRGAN
|
||||
|
||||
esrgan = ESRGAN()
|
||||
logger.info(
|
||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||
)
|
||||
if embiggen[0] > 2:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
4, # upscale scale
|
||||
)
|
||||
else:
|
||||
initsuperimage = esrgan.process(
|
||||
initsuperimage,
|
||||
embiggen[1], # upscale strength
|
||||
self.seed,
|
||||
2, # upscale scale
|
||||
)
|
||||
# We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
|
||||
# but from personal experiance it doesn't greatly improve anything after 4x
|
||||
# Resize to target scaling factor resolution
|
||||
initsuperimage = initsuperimage.resize(
|
||||
(initsuperwidth, initsuperheight), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Use width and height as tile widths and height
|
||||
# Determine buffer size in pixels
|
||||
if embiggen[2] < 1:
|
||||
if embiggen[2] < 0:
|
||||
embiggen[2] = 0
|
||||
overlap_size_x = round(embiggen[2] * width)
|
||||
overlap_size_y = round(embiggen[2] * height)
|
||||
else:
|
||||
overlap_size_x = round(embiggen[2])
|
||||
overlap_size_y = round(embiggen[2])
|
||||
|
||||
# With overall image width and height known, determine how many tiles we need
|
||||
def ceildiv(a, b):
|
||||
return -1 * (-a // b)
|
||||
|
||||
# X and Y needs to be determined independantly (we may have savings on one based on the buffer pixel count)
|
||||
# (initsuperwidth - width) is the area remaining to the right that we need to layers tiles to fill
|
||||
# (width - overlap_size_x) is how much new we can fill with a single tile
|
||||
emb_tiles_x = 1
|
||||
emb_tiles_y = 1
|
||||
if (initsuperwidth - width) > 0:
|
||||
emb_tiles_x = ceildiv(initsuperwidth - width, width - overlap_size_x) + 1
|
||||
if (initsuperheight - height) > 0:
|
||||
emb_tiles_y = ceildiv(initsuperheight - height, height - overlap_size_y) + 1
|
||||
# Sanity
|
||||
assert (
|
||||
emb_tiles_x > 1 or emb_tiles_y > 1
|
||||
), f"ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don't need to Embiggen! Check your arguments."
|
||||
|
||||
# Prep alpha layers --------------
|
||||
# https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
|
||||
# agradientL is Left-side transparent
|
||||
agradientL = (
|
||||
Image.linear_gradient("L").rotate(90).resize((overlap_size_x, height))
|
||||
)
|
||||
# agradientT is Top-side transparent
|
||||
agradientT = Image.linear_gradient("L").resize((width, overlap_size_y))
|
||||
# radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
|
||||
agradientC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
# Find distance to lower right corner (numpy takes arrays)
|
||||
distanceToLR = np.sqrt([(255 - x) ** 2 + (255 - y) ** 2])[0]
|
||||
# Clamp values to max 255
|
||||
if distanceToLR > 255:
|
||||
distanceToLR = 255
|
||||
# Place the pixel as invert of distance
|
||||
agradientC.putpixel((x, y), round(255 - distanceToLR))
|
||||
|
||||
# Create alternative asymmetric diagonal corner to use on "tailing" intersections to prevent hard edges
|
||||
# Fits for a left-fading gradient on the bottom side and full opacity on the right side.
|
||||
agradientAsymC = Image.new("L", (256, 256))
|
||||
for y in range(256):
|
||||
for x in range(256):
|
||||
value = round(max(0, x - (255 - y)) * (255 / max(1, y)))
|
||||
# Clamp values
|
||||
value = max(0, value)
|
||||
value = min(255, value)
|
||||
agradientAsymC.putpixel((x, y), value)
|
||||
|
||||
# Create alpha layers default fully white
|
||||
alphaLayerL = Image.new("L", (width, height), 255)
|
||||
alphaLayerT = Image.new("L", (width, height), 255)
|
||||
alphaLayerLTC = Image.new("L", (width, height), 255)
|
||||
# Paste gradients into alpha layers
|
||||
alphaLayerL.paste(agradientL, (0, 0))
|
||||
alphaLayerT.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientL, (0, 0))
|
||||
alphaLayerLTC.paste(agradientT, (0, 0))
|
||||
alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
|
||||
# make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
|
||||
# to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
|
||||
alphaLayerTaC = alphaLayerT.copy()
|
||||
alphaLayerTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
alphaLayerLTaC = alphaLayerLTC.copy()
|
||||
alphaLayerLTaC.paste(
|
||||
agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
if embiggen_tiles:
|
||||
# Individual unconnected sides
|
||||
alphaLayerR = Image.new("L", (width, height), 255)
|
||||
alphaLayerR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerB = Image.new("L", (width, height), 255)
|
||||
alphaLayerB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerTB = Image.new("L", (width, height), 255)
|
||||
alphaLayerTB.paste(agradientT, (0, 0))
|
||||
alphaLayerTB.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLR = Image.new("L", (width, height), 255)
|
||||
alphaLayerLR.paste(agradientL, (0, 0))
|
||||
alphaLayerLR.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
|
||||
# Sides and corner Layers
|
||||
alphaLayerRBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRBC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerRBC.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerLBC = Image.new("L", (width, height), 255)
|
||||
alphaLayerLBC.paste(agradientL, (0, 0))
|
||||
alphaLayerLBC.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerLBC.paste(
|
||||
agradientC.rotate(90).resize((overlap_size_x, overlap_size_y)),
|
||||
(0, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerRTC = Image.new("L", (width, height), 255)
|
||||
alphaLayerRTC.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerRTC.paste(agradientT, (0, 0))
|
||||
alphaLayerRTC.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# All but X layers
|
||||
alphaLayerABT = Image.new("L", (width, height), 255)
|
||||
alphaLayerABT.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABT.paste(agradientL.rotate(180), (width - overlap_size_x, 0))
|
||||
alphaLayerABT.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABL = Image.new("L", (width, height), 255)
|
||||
alphaLayerABL.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABL.paste(agradientT.rotate(180), (0, height - overlap_size_y))
|
||||
alphaLayerABL.paste(
|
||||
agradientC.rotate(180).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, height - overlap_size_y),
|
||||
)
|
||||
alphaLayerABR = Image.new("L", (width, height), 255)
|
||||
alphaLayerABR.paste(alphaLayerLBC, (0, 0))
|
||||
alphaLayerABR.paste(agradientT, (0, 0))
|
||||
alphaLayerABR.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerABB = Image.new("L", (width, height), 255)
|
||||
alphaLayerABB.paste(alphaLayerRTC, (0, 0))
|
||||
alphaLayerABB.paste(agradientL, (0, 0))
|
||||
alphaLayerABB.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
|
||||
# All-around layer
|
||||
alphaLayerAA = Image.new("L", (width, height), 255)
|
||||
alphaLayerAA.paste(alphaLayerABT, (0, 0))
|
||||
alphaLayerAA.paste(agradientT, (0, 0))
|
||||
alphaLayerAA.paste(
|
||||
agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0)
|
||||
)
|
||||
alphaLayerAA.paste(
|
||||
agradientC.rotate(270).resize((overlap_size_x, overlap_size_y)),
|
||||
(width - overlap_size_x, 0),
|
||||
)
|
||||
|
||||
# Clean up temporary gradients
|
||||
del agradientL
|
||||
del agradientT
|
||||
del agradientC
|
||||
|
||||
def make_image():
|
||||
# Make main tiles -------------------------------------------------
|
||||
if embiggen_tiles:
|
||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||
else:
|
||||
logger.info(
|
||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||
)
|
||||
|
||||
emb_tile_store = []
|
||||
# Although we could use the same seed for every tile for determinism, at higher strengths this may
|
||||
# produce duplicated structures for each tile and make the tiling effect more obvious
|
||||
# instead track and iterate a local seed we pass to Img2Img
|
||||
seed = self.seed
|
||||
seedintlimit = (
|
||||
np.iinfo(np.uint32).max - 1
|
||||
) # only retreive this one from numpy
|
||||
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
# Don't iterate on first tile
|
||||
if tile != 0:
|
||||
if seed < seedintlimit:
|
||||
seed += 1
|
||||
else:
|
||||
seed = 0
|
||||
|
||||
# Determine if this is a re-run and replace
|
||||
if embiggen_tiles and not tile in embiggen_tiles:
|
||||
continue
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
# Determine bounds to cut up the init image
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
right = left + width
|
||||
bottom = top + height
|
||||
|
||||
# Cropped image of above dimension (does not modify the original)
|
||||
newinitimage = initsuperimage.crop((left, top, right, bottom))
|
||||
# DEBUG:
|
||||
# newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
|
||||
# newinitimage.save(newinitimagepath)
|
||||
|
||||
if embiggen_tiles:
|
||||
logger.debug(
|
||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||
|
||||
# create a torch tensor from an Image
|
||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||
newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
|
||||
newinitimage = torch.from_numpy(newinitimage)
|
||||
newinitimage = 2.0 * newinitimage - 1.0
|
||||
newinitimage = newinitimage.to(self.model.device)
|
||||
clear_cuda_cache = (
|
||||
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||
)
|
||||
|
||||
tile_results = gen_img2img.generate(
|
||||
prompt,
|
||||
iterations=1,
|
||||
seed=seed,
|
||||
sampler=sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=conditioning,
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=None, # called only after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
width=width,
|
||||
height=height,
|
||||
init_image=newinitimage, # notice that init_image is different from init_img
|
||||
mask_image=None,
|
||||
strength=strength,
|
||||
clear_cuda_cache=clear_cuda_cache,
|
||||
)
|
||||
|
||||
emb_tile_store.append(tile_results[0][0])
|
||||
# DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
|
||||
# emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
|
||||
del newinitimage
|
||||
|
||||
# Sanity check we have them all
|
||||
if len(emb_tile_store) == (emb_tiles_x * emb_tiles_y) or (
|
||||
embiggen_tiles != [] and len(emb_tile_store) == len(embiggen_tiles)
|
||||
):
|
||||
outputsuperimage = Image.new("RGBA", (initsuperwidth, initsuperheight))
|
||||
if embiggen_tiles:
|
||||
outputsuperimage.alpha_composite(
|
||||
initsuperimage.convert("RGBA"), (0, 0)
|
||||
)
|
||||
for tile in range(emb_tiles_x * emb_tiles_y):
|
||||
if embiggen_tiles:
|
||||
if tile in embiggen_tiles:
|
||||
intileimage = emb_tile_store.pop(0)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
intileimage = emb_tile_store[tile]
|
||||
intileimage = intileimage.convert("RGBA")
|
||||
# Get row and column entries
|
||||
emb_row_i = tile // emb_tiles_x
|
||||
emb_column_i = tile % emb_tiles_x
|
||||
if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
|
||||
left = 0
|
||||
top = 0
|
||||
else:
|
||||
# Determine upper-left point
|
||||
if emb_column_i + 1 == emb_tiles_x:
|
||||
left = initsuperwidth - width
|
||||
else:
|
||||
left = round(emb_column_i * (width - overlap_size_x))
|
||||
if emb_row_i + 1 == emb_tiles_y:
|
||||
top = initsuperheight - height
|
||||
else:
|
||||
top = round(emb_row_i * (height - overlap_size_y))
|
||||
# Handle gradients for various conditions
|
||||
# Handle emb_rerun case
|
||||
if embiggen_tiles:
|
||||
# top of image
|
||||
if emb_row_i == 0:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) not in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerB)
|
||||
# Otherwise do nothing on this tile
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRBC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLBC)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerLR)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABT)
|
||||
# bottom of image
|
||||
elif emb_row_i == emb_tiles_y - 1:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
# No tiles to look ahead to
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
# vertical middle of image
|
||||
else:
|
||||
if emb_column_i == 0:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTB)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerRTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABL)
|
||||
elif emb_column_i == emb_tiles_x - 1:
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
else:
|
||||
if (tile + 1) in embiggen_tiles: # Look-ahead right
|
||||
if (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerABR)
|
||||
elif (
|
||||
tile + emb_tiles_x
|
||||
) in embiggen_tiles: # Look-ahead down only
|
||||
intileimage.putalpha(alphaLayerABB)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerAA)
|
||||
# Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
|
||||
else:
|
||||
if emb_row_i == 0 and emb_column_i >= 1:
|
||||
intileimage.putalpha(alphaLayerL)
|
||||
elif emb_row_i >= 1 and emb_column_i == 0:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerT)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerTaC)
|
||||
else:
|
||||
if (
|
||||
emb_column_i + 1 == emb_tiles_x
|
||||
): # If we don't have anything that can be placed to the right
|
||||
intileimage.putalpha(alphaLayerLTC)
|
||||
else:
|
||||
intileimage.putalpha(alphaLayerLTaC)
|
||||
# Layer tile onto final image
|
||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||
else:
|
||||
logger.error(
|
||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||
)
|
||||
|
||||
# after internal loops and patching up return Embiggen image
|
||||
return outputsuperimage
|
||||
|
||||
# end of function declaration
|
||||
return make_image
|
@ -22,7 +22,6 @@ class Img2Img(Generator):
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
|
@ -161,9 +161,7 @@ class Inpaint(Img2Img):
|
||||
im: Image.Image,
|
||||
seam_size: int,
|
||||
seam_blur: int,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -177,8 +175,6 @@ class Inpaint(Img2Img):
|
||||
mask = self.mask_edge(hard_mask, seam_size, seam_blur)
|
||||
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -203,8 +199,6 @@ class Inpaint(Img2Img):
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -306,7 +300,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
@ -345,9 +338,7 @@ class Inpaint(Img2Img):
|
||||
result,
|
||||
seam_size,
|
||||
seam_blur,
|
||||
prompt,
|
||||
seed,
|
||||
sampler,
|
||||
seam_steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
@ -360,8 +351,6 @@ class Inpaint(Img2Img):
|
||||
|
||||
# Restore original settings
|
||||
self.get_make_image(
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
|
@ -1,125 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.controlnet import MultiControlNetModel
|
||||
|
||||
from ..stable_diffusion import (
|
||||
ConditioningData,
|
||||
PostprocessingSettings,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .base import Generator
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision,
|
||||
control_model: Optional[Union[ControlNetModel, List[ControlNetModel]]] = None,
|
||||
**kwargs):
|
||||
self.control_model = control_model
|
||||
if isinstance(self.control_model, list):
|
||||
self.control_model = MultiControlNetModel(self.control_model)
|
||||
super().__init__(model, precision, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt,
|
||||
sampler,
|
||||
steps,
|
||||
cfg_scale,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width,
|
||||
height,
|
||||
step_callback=None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
control_image = kwargs.get("control_image", None)
|
||||
do_classifier_free_guidance = cfg_scale > 1.0
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.control_model = self.control_model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=warmup,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
if control_image is not None:
|
||||
if isinstance(self.control_model, ControlNetModel):
|
||||
control_image = pipeline.prepare_control_image(
|
||||
image=control_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
elif isinstance(self.control_model, MultiControlNetModel):
|
||||
images = []
|
||||
for image_ in control_image:
|
||||
image_ = pipeline.prepare_control_image(
|
||||
image=image_,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=self.control_model.device,
|
||||
dtype=self.control_model.dtype,
|
||||
)
|
||||
images.append(image_)
|
||||
control_image = images
|
||||
kwargs["control_image"] = control_image
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
return make_image
|
@ -1,209 +0,0 @@
|
||||
"""
|
||||
invokeai.backend.generator.txt2img inherits from invokeai.backend.generator
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
|
||||
from ..stable_diffusion import PostprocessingSettings
|
||||
from .base import Generator
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__(model, precision)
|
||||
self.init_latent = None # for get_noise()
|
||||
|
||||
def get_make_image(
|
||||
self,
|
||||
prompt: str,
|
||||
sampler,
|
||||
steps: int,
|
||||
cfg_scale: float,
|
||||
ddim_eta,
|
||||
conditioning,
|
||||
width: int,
|
||||
height: int,
|
||||
strength: float,
|
||||
step_callback: Optional[Callable] = None,
|
||||
threshold=0.0,
|
||||
warmup=0.2,
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
attention_maps_callback=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and the initial image
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
pipeline: StableDiffusionGeneratorPipeline = self.model
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=threshold,
|
||||
warmup=0.2,
|
||||
h_symmetry_time_pct=h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||
|
||||
def make_image(x_T: torch.Tensor, _: int):
|
||||
first_pass_latent_output, _ = pipeline.latents_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
# Get our initial generation width and height directly from the latent output so
|
||||
# the message below is accurate.
|
||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||
logger.info(
|
||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||
)
|
||||
|
||||
# resizing
|
||||
resized_latents = torch.nn.functional.interpolate(
|
||||
first_pass_latent_output,
|
||||
size=(
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor,
|
||||
),
|
||||
mode="bilinear",
|
||||
)
|
||||
|
||||
# Free up memory from the last generation.
|
||||
clear_cuda_cache = kwargs["clear_cuda_cache"] or None
|
||||
if clear_cuda_cache is not None:
|
||||
clear_cuda_cache()
|
||||
|
||||
second_pass_noise = self.get_noise_like(
|
||||
resized_latents, override_perlin=True
|
||||
)
|
||||
|
||||
# Clear symmetry for the second pass
|
||||
from dataclasses import replace
|
||||
|
||||
new_postprocessing_settings = replace(
|
||||
conditioning_data.postprocessing_settings, h_symmetry_time_pct=None
|
||||
)
|
||||
new_postprocessing_settings = replace(
|
||||
new_postprocessing_settings, v_symmetry_time_pct=None
|
||||
)
|
||||
new_conditioning_data = replace(
|
||||
conditioning_data, postprocessing_settings=new_postprocessing_settings
|
||||
)
|
||||
|
||||
verbosity = get_verbosity()
|
||||
set_verbosity_error()
|
||||
pipeline_output = pipeline.img2img_from_latents_and_embeddings(
|
||||
resized_latents,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=new_conditioning_data,
|
||||
strength=strength,
|
||||
noise=second_pass_noise,
|
||||
callback=step_callback,
|
||||
)
|
||||
set_verbosity(verbosity)
|
||||
|
||||
if (
|
||||
pipeline_output.attention_map_saver is not None
|
||||
and attention_maps_callback is not None
|
||||
):
|
||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||
|
||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||
|
||||
# FIXME: do we really need something entirely different for the inpainting model?
|
||||
|
||||
# in the case of the inpainting model being loaded, the trick of
|
||||
# providing an interpolated latent doesn't work, so we transiently
|
||||
# create a 512x512 PIL image, upscale it, and run the inpainting
|
||||
# over it in img2img mode. Because the inpaing model is so conservative
|
||||
# it doesn't change the image (much)
|
||||
|
||||
return make_image
|
||||
|
||||
def get_noise_like(self, like: torch.Tensor, override_perlin: bool = False):
|
||||
device = like.device
|
||||
if device.type == "mps":
|
||||
x = torch.randn_like(like, device="cpu", dtype=self.torch_dtype()).to(
|
||||
device
|
||||
)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0 and override_perlin == False:
|
||||
shape = like.shape
|
||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||
shape[3], shape[2]
|
||||
)
|
||||
return x
|
||||
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self, width, height, scale=True):
|
||||
# print(f"Get noise: {width}x{height}")
|
||||
if scale:
|
||||
# Scale the input width and height for the initial generation
|
||||
# Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
|
||||
# while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
|
||||
|
||||
aspect = width / height
|
||||
dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
model_area = (
|
||||
dimension * dimension
|
||||
) # hardcoded for now since all models are trained on square images
|
||||
|
||||
if aspect > 1.0:
|
||||
init_height = max(min_dimension, math.sqrt(model_area / aspect))
|
||||
init_width = init_height * aspect
|
||||
else:
|
||||
init_width = max(min_dimension, math.sqrt(model_area * aspect))
|
||||
init_height = init_width / aspect
|
||||
|
||||
scaled_width, scaled_height = trim_to_multiple_of(
|
||||
math.floor(init_width), math.floor(init_height)
|
||||
)
|
||||
|
||||
else:
|
||||
scaled_width = width
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
channels = self.latent_channels
|
||||
if channels == 9:
|
||||
channels = 4 # we don't really want noise for all the mask channels
|
||||
shape = (
|
||||
1,
|
||||
channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor,
|
||||
)
|
||||
if self.use_mps_noise or device.type == "mps":
|
||||
tensor = torch.empty(size=shape, device="cpu")
|
||||
tensor = self.get_noise_like(like=tensor).to(device)
|
||||
else:
|
||||
tensor = torch.empty(size=shape, device=device)
|
||||
tensor = self.get_noise_like(like=tensor)
|
||||
return tensor
|
@ -22,6 +22,10 @@ SAMPLER_CHOICES = [
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
@ -556,8 +556,8 @@ class ModelPatcher:
|
||||
new_tokens_added = None
|
||||
|
||||
try:
|
||||
ti_manager = TextualInversionManager()
|
||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||
|
||||
def _get_trigger(ti, index):
|
||||
@ -650,22 +650,24 @@ class TextualInversionModel:
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens = dict()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, token_ids: list[int]
|
||||
) -> list[int]:
|
||||
|
||||
#if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
# raise ValueError("token_ids must not start with bos_token_id")
|
||||
#if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
# raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
|
@ -266,6 +266,8 @@ class ModelManager(object):
|
||||
for model_key, model_config in config.items():
|
||||
model_name, base_model, model_type = self.parse_key(model_key)
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
# alias for config file
|
||||
model_config["model_format"] = model_config.pop("format")
|
||||
self.models[model_key] = model_class.create_config(**model_config)
|
||||
|
||||
# check config version number and update on disk/RAM if necessary
|
||||
@ -445,38 +447,6 @@ class ModelManager(object):
|
||||
_cache = self.cache,
|
||||
)
|
||||
|
||||
def default_model(self) -> Optional[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
if none is defined.
|
||||
"""
|
||||
for model_key, model_config in self.models.items():
|
||||
if model_config.default:
|
||||
return self.parse_key(model_key)
|
||||
|
||||
for model_key, _ in self.models.items():
|
||||
return self.parse_key(model_key)
|
||||
else:
|
||||
return None # TODO: or redo as (None, None, None)
|
||||
|
||||
def set_default_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> None:
|
||||
"""
|
||||
Set the default model. The change will not take
|
||||
effect until you call model_manager.commit()
|
||||
"""
|
||||
|
||||
model_key = self.model_key(model_name, base_model, model_type)
|
||||
if model_key not in self.models:
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
for cur_model_key, config in self.models.items():
|
||||
config.default = cur_model_key == model_key
|
||||
|
||||
def model_info(
|
||||
self,
|
||||
model_name: str,
|
||||
@ -503,9 +473,9 @@ class ModelManager(object):
|
||||
self,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a dict of models, in format [base_model][model_type][model_name]
|
||||
Return a list of models.
|
||||
|
||||
Please use model_manager.models() to get all the model names,
|
||||
model_manager.model_info('model-name') to get the stanza for the model
|
||||
@ -513,7 +483,7 @@ class ModelManager(object):
|
||||
object derived from models.yaml
|
||||
"""
|
||||
|
||||
models = dict()
|
||||
models = []
|
||||
for model_key in sorted(self.models, key=str.casefold):
|
||||
model_config = self.models[model_key]
|
||||
|
||||
@ -523,18 +493,16 @@ class ModelManager(object):
|
||||
if model_type is not None and cur_model_type != model_type:
|
||||
continue
|
||||
|
||||
if cur_base_model not in models:
|
||||
models[cur_base_model] = dict()
|
||||
if cur_model_type not in models[cur_base_model]:
|
||||
models[cur_base_model][cur_model_type] = dict()
|
||||
|
||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
||||
model_dict = dict(
|
||||
**model_config.dict(exclude_defaults=True),
|
||||
# OpenAPIModelInfoBase
|
||||
name=cur_model_name,
|
||||
base_model=cur_base_model,
|
||||
type=cur_model_type,
|
||||
)
|
||||
|
||||
models.append(model_dict)
|
||||
|
||||
return models
|
||||
|
||||
def print_models(self) -> None:
|
||||
@ -646,7 +614,9 @@ class ModelManager(object):
|
||||
model_class = MODEL_CLASSES[base_model][model_type]
|
||||
if model_class.save_to_config:
|
||||
# TODO: or exclude_unset better fits here?
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True)
|
||||
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
|
||||
# alias for config file
|
||||
data_to_save[model_key]["format"] = data_to_save[model_key].pop("model_format")
|
||||
|
||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||
config_file_path = conf_file or self.config_path
|
||||
|
@ -1,3 +1,7 @@
|
||||
import inspect
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal, get_origin
|
||||
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
|
||||
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
|
||||
from .vae import VaeModel
|
||||
@ -29,10 +33,63 @@ MODEL_CLASSES = {
|
||||
#},
|
||||
}
|
||||
|
||||
def get_all_model_configs():
|
||||
configs = set()
|
||||
for models in MODEL_CLASSES.values():
|
||||
for _, model in models.items():
|
||||
configs.update(model._get_configs().values())
|
||||
configs.discard(None)
|
||||
return list(configs) # TODO: set, list or tuple
|
||||
MODEL_CONFIGS = list()
|
||||
OPENAPI_MODEL_CONFIGS = list()
|
||||
|
||||
class OpenAPIModelInfoBase(BaseModel):
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
type: ModelType
|
||||
|
||||
|
||||
for base_model, models in MODEL_CLASSES.items():
|
||||
for model_type, model_class in models.items():
|
||||
model_configs = set(model_class._get_configs().values())
|
||||
model_configs.discard(None)
|
||||
MODEL_CONFIGS.extend(model_configs)
|
||||
|
||||
for cfg in model_configs:
|
||||
model_name, cfg_name = cfg.__qualname__.split('.')[-2:]
|
||||
openapi_cfg_name = model_name + cfg_name
|
||||
if openapi_cfg_name in vars():
|
||||
continue
|
||||
|
||||
api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict(
|
||||
__annotations__ = dict(
|
||||
type=Literal[model_type.value],
|
||||
),
|
||||
))
|
||||
|
||||
#globals()[openapi_cfg_name] = api_wrapper
|
||||
vars()[openapi_cfg_name] = api_wrapper
|
||||
OPENAPI_MODEL_CONFIGS.append(api_wrapper)
|
||||
|
||||
def get_model_config_enums():
|
||||
enums = list()
|
||||
|
||||
for model_config in MODEL_CONFIGS:
|
||||
fields = inspect.get_annotations(model_config)
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception("format field not found")
|
||||
|
||||
# model_format: None
|
||||
# model_format: SomeModelFormat
|
||||
# model_format: Literal[SomeModelFormat.Diffusers]
|
||||
# model_format: Literal[SomeModelFormat.Diffusers, SomeModelFormat.Checkpoint]
|
||||
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
enums.append(field)
|
||||
|
||||
elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
enums.append(type(field.__args__[0]))
|
||||
|
||||
elif field is None:
|
||||
pass
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported format definition in {model_configs.__qualname__}")
|
||||
|
||||
return enums
|
||||
|
||||
|
@ -48,12 +48,10 @@ class ModelError(str, Enum):
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
path: str # or Path
|
||||
#name: str # not included as present in model key
|
||||
description: Optional[str] = Field(None)
|
||||
format: Optional[str] = Field(None)
|
||||
default: Optional[bool] = Field(False)
|
||||
model_format: Optional[str] = Field(None)
|
||||
# do not save to config
|
||||
error: Optional[ModelError] = Field(None, exclude=True)
|
||||
error: Optional[ModelError] = Field(None)
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@ -94,6 +92,11 @@ class ModelBase(metaclass=ABCMeta):
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if all(t is None for t in subtypes):
|
||||
return None
|
||||
elif any(t is None for t in subtypes):
|
||||
raise Exception(f"Unsupported definition: {subtypes}")
|
||||
|
||||
if subtypes[0] in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[subtypes[0]]
|
||||
subtypes = subtypes[1:]
|
||||
@ -122,47 +125,41 @@ class ModelBase(metaclass=ABCMeta):
|
||||
continue
|
||||
|
||||
fields = inspect.get_annotations(value)
|
||||
if "format" not in fields:
|
||||
raise Exception("Invalid config definition - format field not found")
|
||||
try:
|
||||
field = fields["model_format"]
|
||||
except:
|
||||
raise Exception(f"Invalid config definition - format field not found({cls.__qualname__})")
|
||||
|
||||
format_type = typing.get_origin(fields["format"])
|
||||
if format_type not in {None, Literal, Union}:
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum):
|
||||
for model_format in field:
|
||||
configs[model_format.value] = value
|
||||
|
||||
if format_type is Union and not all(typing.get_origin(v) in {None, Literal} for v in fields["format"].__args__):
|
||||
raise Exception(f"Invalid config definition - unknown format type: {fields['format']}")
|
||||
elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__):
|
||||
for model_format in field.__args__:
|
||||
configs[model_format.value] = value
|
||||
|
||||
elif field is None:
|
||||
configs[None] = value
|
||||
|
||||
if format_type == Union:
|
||||
f_fields = fields["format"].__args__
|
||||
else:
|
||||
f_fields = (fields["format"],)
|
||||
|
||||
|
||||
for field in f_fields:
|
||||
if field is None:
|
||||
format_name = None
|
||||
else:
|
||||
format_name = field.__args__[0]
|
||||
|
||||
configs[format_name] = value # TODO: error when override(multiple)?
|
||||
|
||||
raise Exception(f"Unsupported format definition in {cls.__qualname__}")
|
||||
|
||||
cls.__configs = configs
|
||||
return cls.__configs
|
||||
|
||||
@classmethod
|
||||
def create_config(cls, **kwargs) -> ModelConfigBase:
|
||||
if "format" not in kwargs:
|
||||
raise Exception("Field 'format' not found in model config")
|
||||
if "model_format" not in kwargs:
|
||||
raise Exception("Field 'model_format' not found in model config")
|
||||
|
||||
configs = cls._get_configs()
|
||||
return configs[kwargs["format"]](**kwargs)
|
||||
return configs[kwargs["model_format"]](**kwargs)
|
||||
|
||||
@classmethod
|
||||
def probe_config(cls, path: str, **kwargs) -> ModelConfigBase:
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=cls.detect_format(path),
|
||||
model_format=cls.detect_format(path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
@ -14,12 +15,16 @@ from .base import (
|
||||
classproperty,
|
||||
)
|
||||
|
||||
class ControlNetModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class ControlNetModel(ModelBase):
|
||||
#model_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: ControlNetModelFormat
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.ControlNet
|
||||
@ -69,9 +74,9 @@ class ControlNetModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return ControlNetModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return ControlNetModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -81,7 +86,7 @@ class ControlNetModel(ModelBase):
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
raise NotImlemetedError("Checkpoint controlnet models currently unsupported")
|
||||
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
|
||||
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
|
||||
else:
|
||||
return model_path
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
ModelBase,
|
||||
@ -12,11 +13,15 @@ from .base import (
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
||||
class LoRAModelFormat(str, Enum):
|
||||
LyCORIS = "lycoris"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["lycoris"], Literal["diffusers"]]
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
@ -52,9 +57,9 @@ class LoRAModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return LoRAModelFormat.Diffusers
|
||||
else:
|
||||
return "lycoris"
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -64,7 +69,7 @@ class LoRAModel(ModelBase):
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == "diffusers":
|
||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||
# TODO: add diffusers lora when it stabilizes a bit
|
||||
raise NotImplementedError("Diffusers lora not supported")
|
||||
else:
|
||||
|
@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
from enum import Enum
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Union
|
||||
@ -19,16 +20,19 @@ from .base import (
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class StableDiffusion1ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -47,7 +51,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == "checkpoint":
|
||||
if model_format == StableDiffusion1ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
@ -57,7 +61,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == "diffusers":
|
||||
elif model_format == StableDiffusion1ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
@ -80,7 +84,7 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
@ -93,9 +97,9 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return "diffusers"
|
||||
return StableDiffusion1ModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return StableDiffusion1ModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -116,19 +120,22 @@ class StableDiffusion1Model(DiffusersModel):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
class StableDiffusion2ModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
# TODO: check that configs overwriten properly
|
||||
class DiffusersConfig(ModelConfigBase):
|
||||
format: Literal["diffusers"]
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
|
||||
vae: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
format: Literal["checkpoint"]
|
||||
model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
|
||||
vae: Optional[str] = Field(None)
|
||||
config: Optional[str] = Field(None)
|
||||
variant: ModelVariantType
|
||||
@ -149,7 +156,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
def probe_config(cls, path: str, **kwargs):
|
||||
model_format = cls.detect_format(path)
|
||||
ckpt_config_path = kwargs.get("config", None)
|
||||
if model_format == "checkpoint":
|
||||
if model_format == StableDiffusion2ModelFormat.Checkpoint:
|
||||
if ckpt_config_path:
|
||||
ckpt_config = OmegaConf.load(ckpt_config_path)
|
||||
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
|
||||
@ -159,7 +166,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
checkpoint = checkpoint.get('state_dict', checkpoint)
|
||||
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
|
||||
elif model_format == "diffusers":
|
||||
elif model_format == StableDiffusion2ModelFormat.Diffusers:
|
||||
unet_config_path = os.path.join(path, "unet", "config.json")
|
||||
if os.path.exists(unet_config_path):
|
||||
with open(unet_config_path, "r") as f:
|
||||
@ -191,7 +198,7 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
|
||||
return cls.create_config(
|
||||
path=path,
|
||||
format=model_format,
|
||||
model_format=model_format,
|
||||
|
||||
config=ckpt_config_path,
|
||||
variant=variant,
|
||||
@ -206,9 +213,9 @@ class StableDiffusion2Model(DiffusersModel):
|
||||
@classmethod
|
||||
def detect_format(cls, model_path: str):
|
||||
if os.path.isdir(model_path):
|
||||
return "diffusers"
|
||||
return StableDiffusion2ModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return StableDiffusion2ModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -281,8 +288,8 @@ def _convert_ckpt_and_cache(
|
||||
prediction_type = SchedulerPredictionType.Epsilon
|
||||
|
||||
elif version == BaseModelType.StableDiffusion2:
|
||||
upcast_attention = config.upcast_attention
|
||||
prediction_type = config.prediction_type
|
||||
upcast_attention = model_config.upcast_attention
|
||||
prediction_type = model_config.prediction_type
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown model provided: {version}")
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .base import (
|
||||
@ -15,7 +16,7 @@ class TextualInversionModel(ModelBase):
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: None
|
||||
model_format: None
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.TextualInversion
|
||||
|
@ -1,5 +1,7 @@
|
||||
import os
|
||||
import torch
|
||||
import safetensors
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Literal
|
||||
from .base import (
|
||||
@ -18,12 +20,16 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from diffusers.utils import is_safetensors_available
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
class VaeModelFormat(str, Enum):
|
||||
Checkpoint = "checkpoint"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
class VaeModel(ModelBase):
|
||||
#vae_class: Type
|
||||
#model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
format: Union[Literal["checkpoint"], Literal["diffusers"]]
|
||||
model_format: VaeModelFormat
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Vae
|
||||
@ -70,9 +76,9 @@ class VaeModel(ModelBase):
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if os.path.isdir(path):
|
||||
return "diffusers"
|
||||
return VaeModelFormat.Diffusers
|
||||
else:
|
||||
return "checkpoint"
|
||||
return VaeModelFormat.Checkpoint
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -82,7 +88,7 @@ class VaeModel(ModelBase):
|
||||
config: ModelConfigBase, # empty config or config of parent model
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) != "diffusers":
|
||||
if cls.detect_format(model_path) == VaeModelFormat.Checkpoint:
|
||||
return _convert_vae_ckpt_and_cache(
|
||||
weights_path=model_path,
|
||||
output_path=output_path,
|
||||
|
@ -1,9 +0,0 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend.prompting
|
||||
"""
|
||||
from .conditioning import (
|
||||
get_prompt_structure,
|
||||
get_tokens_for_prompt_object,
|
||||
get_uc_and_c_and_ec,
|
||||
split_weighted_subprompts,
|
||||
)
|
@ -1,297 +0,0 @@
|
||||
"""
|
||||
This module handles the generation of the conditioning tensors.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
|
||||
"""
|
||||
import re
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
from compel import Compel
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
PromptParser,
|
||||
Conjunction,
|
||||
)
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||
from ..util import torch_dtype
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string,
|
||||
model: InvokeAIDiffuserComponent,
|
||||
log_tokens=False, skip_normalize_legacy_blend=False):
|
||||
# lazy-load any deferred textual inversions.
|
||||
# this might take a couple of seconds the first time a textual inversion is used.
|
||||
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
|
||||
|
||||
compel = Compel(tokenizer=model.tokenizer,
|
||||
text_encoder=model.text_encoder,
|
||||
textual_inversion_manager=model.textual_inversion_manager,
|
||||
dtype_for_device_getter=torch_dtype,
|
||||
truncate_long_prompts=False,
|
||||
)
|
||||
|
||||
# get rid of any newline characters
|
||||
prompt_string = prompt_string.replace("\n", " ")
|
||||
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
|
||||
|
||||
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
|
||||
positive_conjunction: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
|
||||
|
||||
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
|
||||
if log_tokens or config.log_tokenization:
|
||||
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
|
||||
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
|
||||
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
|
||||
cross_attention_control_args=options.get(
|
||||
'cross_attention_control', None))
|
||||
return uc, c, ec
|
||||
|
||||
def get_prompt_structure(
|
||||
prompt_string, skip_normalize_legacy_blend: bool = False
|
||||
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
|
||||
(
|
||||
positive_prompt_string,
|
||||
negative_prompt_string,
|
||||
) = split_prompt_to_positive_and_negative(prompt_string)
|
||||
legacy_blend = try_parse_legacy_blend(
|
||||
positive_prompt_string, skip_normalize_legacy_blend
|
||||
)
|
||||
positive_prompt: Conjunction
|
||||
if legacy_blend is not None:
|
||||
positive_conjunction = legacy_blend
|
||||
else:
|
||||
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
|
||||
positive_prompt = positive_conjunction.prompts[0]
|
||||
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
|
||||
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
|
||||
|
||||
return positive_prompt, negative_prompt
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
return max(
|
||||
[
|
||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
||||
for c in blend.prompts
|
||||
]
|
||||
)
|
||||
else:
|
||||
return len(
|
||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
||||
)
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||
) -> [str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError(
|
||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||
)
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (
|
||||
" ".join([f.text for f in x.original])
|
||||
if type(x) is CrossAttentionControlSubstitute
|
||||
else str(x)
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def split_prompt_to_positive_and_negative(prompt_string_uncleaned: str):
|
||||
unconditioned_words = ""
|
||||
unconditional_regex = r"\[(.*?)\]"
|
||||
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
|
||||
if len(unconditionals) > 0:
|
||||
unconditioned_words = " ".join(unconditionals)
|
||||
|
||||
# Remove Unconditioned Words From Prompt
|
||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
||||
clean_prompt = unconditional_regex_compile.sub(" ", prompt_string_uncleaned)
|
||||
prompt_string_cleaned = re.sub(" +", " ", clean_prompt)
|
||||
else:
|
||||
prompt_string_cleaned = prompt_string_uncleaned
|
||||
return prompt_string_cleaned, unconditioned_words
|
||||
|
||||
|
||||
def log_tokenization(
|
||||
positive_prompt: Union[Blend, FlattenedPrompt],
|
||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||
tokenizer,
|
||||
):
|
||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||
|
||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||
log_tokenization_for_prompt_object(
|
||||
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||
):
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
for i, c in enumerate(blend.prompts):
|
||||
log_tokenization_for_prompt_object(
|
||||
c,
|
||||
tokenizer,
|
||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
||||
)
|
||||
elif type(p) is FlattenedPrompt:
|
||||
flattened_prompt: FlattenedPrompt = p
|
||||
if flattened_prompt.wants_cross_attention_control:
|
||||
original_fragments = []
|
||||
edited_fragments = []
|
||||
for f in flattened_prompt.children:
|
||||
if type(f) is CrossAttentionControlSubstitute:
|
||||
original_fragments += f.original
|
||||
edited_fragments += f.edited
|
||||
else:
|
||||
original_fragments.append(f)
|
||||
edited_fragments.append(f)
|
||||
|
||||
original_text = " ".join([x.text for x in original_fragments])
|
||||
log_tokenization_for_text(
|
||||
original_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap originals)",
|
||||
)
|
||||
edited_text = " ".join([x.text for x in edited_fragments])
|
||||
log_tokenization_for_text(
|
||||
edited_text,
|
||||
tokenizer,
|
||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
||||
)
|
||||
else:
|
||||
text = " ".join([x.text for x in flattened_prompt.children])
|
||||
log_tokenization_for_text(
|
||||
text, tokenizer, display_label=display_label_prefix
|
||||
)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
"""
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace("</w>", " ")
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
else:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
|
||||
if usedTokens > 0:
|
||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||
logger.debug(f"{tokenized}\x1b[0m")
|
||||
|
||||
if discarded != "":
|
||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||
logger.debug(f"{discarded}\x1b[0m")
|
||||
|
||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
|
||||
if len(weighted_subprompts) <= 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
|
||||
pp = PromptParser()
|
||||
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = []
|
||||
weights = []
|
||||
for i, x in enumerate(parsed_conjunctions):
|
||||
if len(x.prompts)>0:
|
||||
flattened_prompts.append(x.prompts[0])
|
||||
weights.append(weighted_subprompts[i][1])
|
||||
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile(
|
||||
"""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
parsed_prompts = [
|
||||
(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
|
||||
for match in re.finditer(prompt_parser, text)
|
||||
]
|
||||
if len(parsed_prompts) == 0:
|
||||
return []
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
logger.warning(
|
||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||
)
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
from .diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
@ -10,4 +9,3 @@ from .diffusers_pipeline import (
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
@ -1,275 +0,0 @@
|
||||
"""
|
||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
||||
at https://huggingface.co/sd-concepts-library.
|
||||
|
||||
The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
"""
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
"""
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.root = root or self.config.root
|
||||
self.hf_api = HfApi()
|
||||
self.local_concepts = dict()
|
||||
self.concept_list = None
|
||||
self.concepts_loaded = dict()
|
||||
self.triggers = dict() # concept name to trigger phrase
|
||||
self.concept_names = dict() # trigger phrase to concept name
|
||||
self.match_trigger = re.compile(
|
||||
"(<[\w\- >]+>)"
|
||||
) # trigger is slightly less restrictive than HF concept name
|
||||
self.match_concept = re.compile(
|
||||
"<([\w\-]+)>"
|
||||
) # HF concept name can only contain A-Za-z0-9_-
|
||||
|
||||
def list_concepts(self) -> list:
|
||||
"""
|
||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
||||
Also adds local concepts in invokeai/embeddings folder.
|
||||
"""
|
||||
local_concepts_now = self.get_local_concepts(
|
||||
os.path.join(self.root, "embeddings")
|
||||
)
|
||||
local_concepts_to_add = set(local_concepts_now).difference(
|
||||
set(self.local_concepts)
|
||||
)
|
||||
self.local_concepts.update(local_concepts_now)
|
||||
|
||||
if self.concept_list is not None:
|
||||
if local_concepts_to_add:
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
elif self.config.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
)
|
||||
self.concept_list = [a.id.split("/")[1] for a in models]
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
Returns the path to the 'learned_embeds.bin' file in
|
||||
the named concept. Returns None if invalid or cannot
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
||||
|
||||
def concept_to_trigger(self, concept_name: str) -> str:
|
||||
"""
|
||||
Given a concept name returns its trigger by looking in the
|
||||
"token_identifier.txt" file.
|
||||
"""
|
||||
if concept_name in self.triggers:
|
||||
return self.triggers[concept_name]
|
||||
elif self.concept_is_local(concept_name):
|
||||
trigger = f"<{concept_name}>"
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
file = self.get_concept_file(
|
||||
concept_name, "token_identifier.txt", local_only=True
|
||||
)
|
||||
if not file:
|
||||
return None
|
||||
with open(file, "r") as f:
|
||||
trigger = f.readline()
|
||||
trigger = trigger.strip()
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
def trigger_to_concept(self, trigger: str) -> str:
|
||||
"""
|
||||
Given a trigger phrase, maps it to the concept library name.
|
||||
Only works if concept_to_trigger() has previously been called
|
||||
on this library. There needs to be a persistent database for
|
||||
this.
|
||||
"""
|
||||
concept = self.concept_names.get(trigger, None)
|
||||
return f"<{concept}>" if concept else f"{trigger}"
|
||||
|
||||
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
||||
"""
|
||||
Given a prompt string that contains <trigger> tags, replace these
|
||||
tags with the concept name. The reason for this is so that the
|
||||
concept names get stored in the prompt metadata. There is no
|
||||
controlling of colliding triggers in the SD library, so it is
|
||||
better to store the concept name (unique) than the concept trigger
|
||||
(not necessarily unique!)
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
triggers = self.match_trigger.findall(prompt)
|
||||
if not triggers:
|
||||
return prompt
|
||||
|
||||
def do_replace(match) -> str:
|
||||
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_trigger.sub(do_replace, prompt)
|
||||
|
||||
def replace_concepts_with_triggers(
|
||||
self,
|
||||
prompt: str,
|
||||
load_concepts_callback: Callable[[list], any],
|
||||
excluded_tokens: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Given a prompt string that contains `<concept_name>` tags, replace
|
||||
these tags with the appropriate trigger.
|
||||
|
||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||
of `concepts_name` strings.
|
||||
|
||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||
are trigger tokens from a locally-loaded embedding.
|
||||
"""
|
||||
concepts = self.match_concept.findall(prompt)
|
||||
if not concepts:
|
||||
return prompt
|
||||
load_concepts_callback(concepts)
|
||||
|
||||
def do_replace(match) -> str:
|
||||
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
||||
return f"<{match.group(1)}>"
|
||||
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_concept.sub(do_replace, prompt)
|
||||
|
||||
def get_concept_file(
|
||||
self,
|
||||
concept_name: str,
|
||||
file_name: str = "learned_embeds.bin",
|
||||
local_only: bool = False,
|
||||
) -> str:
|
||||
if not (
|
||||
self.concept_is_downloaded(concept_name)
|
||||
or self.concept_is_local(concept_name)
|
||||
or local_only
|
||||
):
|
||||
self.download_concept(concept_name)
|
||||
|
||||
# get local path in invokeai/embeddings if local concept
|
||||
if self.concept_is_local(concept_name):
|
||||
concept_path = self._concept_local_path(concept_name)
|
||||
path = concept_path
|
||||
else:
|
||||
concept_path = self._concept_path(concept_name)
|
||||
path = os.path.join(concept_path, file_name)
|
||||
return path if os.path.exists(path) else None
|
||||
|
||||
def concept_is_local(self, concept_name) -> bool:
|
||||
return concept_name in self.local_concepts
|
||||
|
||||
def concept_is_downloaded(self, concept_name) -> bool:
|
||||
concept_directory = self._concept_path(concept_name)
|
||||
return os.path.exists(concept_directory)
|
||||
|
||||
def download_concept(self, concept_name) -> bool:
|
||||
repo_id = self._concept_id(concept_name)
|
||||
dest = self._concept_path(concept_name)
|
||||
|
||||
access_token = HfFolder.get_token()
|
||||
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
||||
opener = request.build_opener()
|
||||
opener.addheaders = header
|
||||
request.install_opener(opener)
|
||||
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
succeeded = True
|
||||
|
||||
bytes = 0
|
||||
|
||||
def tally_download_size(chunk, size, total):
|
||||
nonlocal bytes
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
"learned_embeds.bin",
|
||||
"token_identifier.txt",
|
||||
"type_of_concept.txt",
|
||||
):
|
||||
url = hf_hub_url(repo_id, file)
|
||||
request.urlretrieve(
|
||||
url, os.path.join(dest, file), reporthook=tally_download_size
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
return f"sd-concepts-library/{concept_name}"
|
||||
|
||||
def _concept_path(self, concept_name: str) -> str:
|
||||
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
||||
|
||||
def _concept_local_path(self, concept_name: str) -> str:
|
||||
filename = self.local_concepts[concept_name]
|
||||
return os.path.join(self.root, "embeddings", filename)
|
||||
|
||||
def get_local_concepts(self, loc_dir: str):
|
||||
locs_dic = dict()
|
||||
if os.path.isdir(loc_dir):
|
||||
for file in os.listdir(loc_dir):
|
||||
f = os.path.splitext(file)
|
||||
if f[1] == ".bin" or f[1] == ".pt":
|
||||
locs_dic[f[0]] = file
|
||||
return locs_dic
|
@ -16,7 +16,6 @@ from accelerate.utils import set_seed
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from compel import EmbeddingsProvider
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@ -48,7 +47,6 @@ from .diffusion import (
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -317,6 +315,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = "float32",
|
||||
control_model: ControlNetModel = None,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@ -341,22 +340,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||
self.unet, self._unet_forward, is_running_diffusers=True
|
||||
)
|
||||
use_full_precision = precision == "float32" or precision == "autocast"
|
||||
self.textual_inversion_manager = TextualInversionManager(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision,
|
||||
)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager,
|
||||
self.unet, self._unet_forward
|
||||
)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
self.control_model = control_model
|
||||
|
||||
@ -404,50 +391,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
def enable_offload_submodels(self, device: torch.device):
|
||||
"""
|
||||
Offload each submodel when it's not in use.
|
||||
|
||||
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
|
||||
the total available resource, and you want to free up as much for inference as possible.
|
||||
|
||||
This requires more moving parts and may add some delay as the U-Net is swapped out for the
|
||||
VAE and vice-versa.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = LazilyLoadedModelGroup(device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def disable_offload_submodels(self):
|
||||
"""
|
||||
Leave all submodels loaded.
|
||||
|
||||
Appropriate for cases where the size of the model in memory is small compared to the memory
|
||||
required for inference. Avoids the delay and complexity of shuffling the submodels to and
|
||||
from the GPU.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = FullyLoadedModelGroup(self._model_group.execution_device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def offload_all(self):
|
||||
"""Offload all this pipeline's models to CPU."""
|
||||
self._model_group.offload_current()
|
||||
|
||||
def ready(self):
|
||||
"""
|
||||
Ready this pipeline's models.
|
||||
|
||||
i.e. preload them to the GPU if appropriate.
|
||||
"""
|
||||
self._model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
@ -991,25 +934,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
device = self._model_group.device_for(self.safety_checker)
|
||||
return super().run_safety_checker(image, device, dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(
|
||||
self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
|
||||
):
|
||||
"""
|
||||
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
|
||||
text_batch=c,
|
||||
fragment_weights_batch=fragment_weights,
|
||||
should_return_tokens=return_tokens,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.config.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
self._model_group.load(self.vae)
|
||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
model,
|
||||
model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool = False,
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
# apparently unused code
|
||||
# TODO: delete
|
||||
# def override_cross_attention(
|
||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
# ) -> Dict[str, AttentionProcessor]:
|
||||
# """
|
||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
# the previous attention processor is returned so that the caller can restore it later.
|
||||
# """
|
||||
# self.conditioning = conditioning
|
||||
# self.cross_attention_control_context = Context(
|
||||
# arguments=self.conditioning.cross_attention_control_args,
|
||||
# step_count=step_count,
|
||||
# )
|
||||
# return override_cross_attention(
|
||||
# self.model,
|
||||
# self.cross_attention_control_context,
|
||||
# is_running_diffusers=self.is_running_diffusers,
|
||||
# )
|
||||
|
||||
def restore_default_cross_attention(
|
||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||
):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
restore_default_cross_attention(
|
||||
self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor,
|
||||
)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_threshold(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = (
|
||||
step_index / total_step_count
|
||||
) # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||
)
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
# TODO: looks unused
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = (
|
||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, edited_conditioning, **kwargs,
|
||||
)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
|
@ -1,7 +1,7 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
@ -21,5 +21,9 @@ SCHEDULER_MAP = dict(
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
@ -1,429 +0,0 @@
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
class EmbeddingInfo:
|
||||
name: str
|
||||
embedding: torch.Tensor
|
||||
num_vectors_per_token: int
|
||||
token_dim: int
|
||||
trained_steps: int = None
|
||||
trained_model_name: str = None
|
||||
trained_model_checksum: str = None
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
embedding: torch.Tensor
|
||||
trigger_token_id: Optional[int] = None
|
||||
pad_token_ids: Optional[list[int]] = None
|
||||
|
||||
@property
|
||||
def embedding_vector_length(self) -> int:
|
||||
return self.embedding.shape[0]
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
full_precision: bool = True,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.full_precision = full_precision
|
||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||
self.trigger_to_sourcefile = dict()
|
||||
default_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = default_textual_inversions
|
||||
|
||||
def load_huggingface_concepts(self, concepts: list[str]):
|
||||
for concept_name in concepts:
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if (
|
||||
self.has_textual_inversion_for_trigger_string(trigger)
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||
): # in case a token with literal angle brackets encountered
|
||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||
|
||||
def get_all_trigger_strings(self) -> list[str]:
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(
|
||||
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
|
||||
):
|
||||
ckpt_path = Path(ckpt_path)
|
||||
|
||||
if not ckpt_path.is_file():
|
||||
return
|
||||
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
logger.warning(
|
||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info.name
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else ckpt_path.name
|
||||
)
|
||||
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
logger.info(
|
||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info.embedding,
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
) -> Optional[TextualInversion]:
|
||||
"""
|
||||
Add a textual inversion to be recognised.
|
||||
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||
:return: The token id for the added embedding, either existing or newly-added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
logger.warning(
|
||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
)
|
||||
return
|
||||
if not self.full_precision:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
elif len(embedding.shape) > 2:
|
||||
raise ValueError(
|
||||
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
||||
)
|
||||
|
||||
try:
|
||||
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
|
||||
if not defer_injecting_tokens:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
self.textual_inversions.append(ti)
|
||||
return ti
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith("Warning"):
|
||||
logger.warning(f"{str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
logger.error(
|
||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
)
|
||||
raise
|
||||
|
||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
||||
if ti.trigger_token_id is not None:
|
||||
raise ValueError(
|
||||
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
|
||||
)
|
||||
|
||||
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
|
||||
ti.trigger_string, ti.embedding[0]
|
||||
)
|
||||
|
||||
if ti.embedding_vector_length > 1:
|
||||
# for embeddings with vector length > 1
|
||||
pad_token_strings = [
|
||||
ti.trigger_string + "-!pad-" + str(pad_index)
|
||||
for pad_index in range(1, ti.embedding_vector_length)
|
||||
]
|
||||
# todo: batched UI for faster loading when vector length >2
|
||||
pad_token_ids = [
|
||||
self._get_or_create_token_id_and_assign_embedding(
|
||||
pad_token_str, ti.embedding[1 + i]
|
||||
)
|
||||
for (i, pad_token_str) in enumerate(pad_token_strings)
|
||||
]
|
||||
else:
|
||||
pad_token_ids = []
|
||||
|
||||
ti.trigger_token_id = trigger_token_id
|
||||
ti.pad_token_ids = pad_token_ids
|
||||
return ti.trigger_token_id
|
||||
|
||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||
try:
|
||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||
return ti is not None
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def get_textual_inversion_for_trigger_string(
|
||||
self, trigger_string: str
|
||||
) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
|
||||
)
|
||||
|
||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
|
||||
)
|
||||
|
||||
def create_deferred_token_ids_for_any_trigger_terms(
|
||||
self, prompt_string: str
|
||||
) -> list[int]:
|
||||
injected_token_ids = []
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
logger.info(
|
||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
)
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
)
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
return injected_token_ids
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, prompt_token_ids: list[int]
|
||||
) -> list[int]:
|
||||
"""
|
||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||
|
||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
||||
"""
|
||||
if len(prompt_token_ids) == 0:
|
||||
return prompt_token_ids
|
||||
|
||||
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||
textual_inversion_trigger_token_ids = [
|
||||
ti.trigger_token_id for ti in self.textual_inversions
|
||||
]
|
||||
prompt_token_ids = prompt_token_ids.copy()
|
||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||
if token_id in textual_inversion_trigger_token_ids:
|
||||
textual_inversion = next(
|
||||
ti
|
||||
for ti in self.textual_inversions
|
||||
if ti.trigger_token_id == token_id
|
||||
)
|
||||
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
|
||||
prompt_token_ids.insert(
|
||||
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
|
||||
)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _get_or_create_token_id_and_assign_embedding(
|
||||
self, token_str: str, embedding: torch.Tensor
|
||||
) -> int:
|
||||
if len(embedding.shape) != 1:
|
||||
raise ValueError(
|
||||
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
|
||||
)
|
||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if existing_token_id == self.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
# the following call is slow - todo make batched for better performance with vector length >1
|
||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id == self.tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||
if (
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
|
||||
!= embedding.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
|
||||
)
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
return token_id
|
||||
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
logger.critical(
|
||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||
|
||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||
|
||||
elif 'emb_params' in keys:
|
||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||
|
||||
else:
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = '<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
trained_steps = embedding_ckpt["step"],
|
||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> List[EmbeddingInfo]:
|
||||
"""
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
token_counter = 0
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
trigger = token if token != '*' \
|
||||
else f'<{basename}>' if token_counter == 0 \
|
||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
return [embedding_info]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
"""
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = token or f"<{basename}>",
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||
token_dim = embedding.size()[0],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
@ -20,6 +20,10 @@ SAMPLER_CHOICES = [
|
||||
"dpmpp_2s_k",
|
||||
"dpmpp_2m",
|
||||
"dpmpp_2m_k",
|
||||
"dpmpp_2m_sde",
|
||||
"dpmpp_2m_sde_k",
|
||||
"dpmpp_sde",
|
||||
"dpmpp_sde_k",
|
||||
"unipc",
|
||||
]
|
||||
|
||||
|
@ -23,6 +23,8 @@ import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { useListModelsQuery } from 'services/apiSlice';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
@ -45,6 +47,18 @@ const App = ({
|
||||
|
||||
const isApplicationReady = useIsApplicationReady();
|
||||
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
const { data: controlnetModels } = useListModelsQuery({
|
||||
model_type: 'controlnet',
|
||||
});
|
||||
const { data: vaeModels } = useListModelsQuery({ model_type: 'vae' });
|
||||
const { data: loraModels } = useListModelsQuery({ model_type: 'lora' });
|
||||
const { data: embeddingModels } = useListModelsQuery({
|
||||
model_type: 'embedding',
|
||||
});
|
||||
|
||||
const [loadingOverridden, setLoadingOverridden] = useState(false);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
@ -143,6 +157,7 @@ const App = ({
|
||||
</Portal>
|
||||
</Grid>
|
||||
<DeleteImageModal />
|
||||
<UpdateImageBoardModal />
|
||||
<Toaster />
|
||||
<GlobalHotkeys />
|
||||
</>
|
||||
|
@ -21,6 +21,8 @@ import {
|
||||
DeleteImageContext,
|
||||
DeleteImageContextProvider,
|
||||
} from 'app/contexts/DeleteImageContext';
|
||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
||||
|
||||
const App = lazy(() => import('./App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
|
||||
@ -76,11 +78,13 @@ const InvokeAIUI = ({
|
||||
<ThemeLocaleProvider>
|
||||
<ImageDndContext>
|
||||
<DeleteImageContextProvider>
|
||||
<AddImageToBoardContextProvider>
|
||||
<App
|
||||
config={config}
|
||||
headerComponent={headerComponent}
|
||||
setIsReady={setIsReady}
|
||||
/>
|
||||
</AddImageToBoardContextProvider>
|
||||
</DeleteImageContextProvider>
|
||||
</ImageDndContext>
|
||||
</ThemeLocaleProvider>
|
||||
|
@ -9,6 +9,8 @@ export const SCHEDULER_NAMES_AS_CONST = [
|
||||
'ddpm',
|
||||
'dpmpp_2s',
|
||||
'dpmpp_2m',
|
||||
'dpmpp_2m_sde',
|
||||
'dpmpp_sde',
|
||||
'heun',
|
||||
'kdpm_2',
|
||||
'lms',
|
||||
@ -17,6 +19,8 @@ export const SCHEDULER_NAMES_AS_CONST = [
|
||||
'euler_k',
|
||||
'dpmpp_2s_k',
|
||||
'dpmpp_2m_k',
|
||||
'dpmpp_2m_sde_k',
|
||||
'dpmpp_sde_k',
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
@ -32,16 +36,20 @@ export const SCHEDULER_LABEL_MAP: Record<SchedulerParam, string> = {
|
||||
deis: 'DEIS',
|
||||
ddim: 'DDIM',
|
||||
ddpm: 'DDPM',
|
||||
dpmpp_sde: 'DPM++ SDE',
|
||||
dpmpp_2s: 'DPM++ 2S',
|
||||
dpmpp_2m: 'DPM++ 2M',
|
||||
dpmpp_2m_sde: 'DPM++ 2M SDE',
|
||||
heun: 'Heun',
|
||||
kdpm_2: 'KDPM 2',
|
||||
lms: 'LMS',
|
||||
pndm: 'PNDM',
|
||||
unipc: 'UniPC',
|
||||
euler_k: 'Euler Karras',
|
||||
dpmpp_sde_k: 'DPM++ SDE Karras',
|
||||
dpmpp_2s_k: 'DPM++ 2S Karras',
|
||||
dpmpp_2m_k: 'DPM++ 2M Karras',
|
||||
dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras',
|
||||
heun_k: 'Heun Karras',
|
||||
lms_k: 'LMS Karras',
|
||||
euler_a: 'Euler Ancestral',
|
||||
|
@ -0,0 +1,89 @@
|
||||
import { useDisclosure } from '@chakra-ui/react';
|
||||
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useAddImageToBoardMutation } from 'services/apiSlice';
|
||||
|
||||
export type ImageUsage = {
|
||||
isInitialImage: boolean;
|
||||
isCanvasImage: boolean;
|
||||
isNodesImage: boolean;
|
||||
isControlNetImage: boolean;
|
||||
};
|
||||
|
||||
type AddImageToBoardContextValue = {
|
||||
/**
|
||||
* Whether the move image dialog is open.
|
||||
*/
|
||||
isOpen: boolean;
|
||||
/**
|
||||
* Closes the move image dialog.
|
||||
*/
|
||||
onClose: () => void;
|
||||
/**
|
||||
* The image pending movement
|
||||
*/
|
||||
image?: ImageDTO;
|
||||
onClickAddToBoard: (image: ImageDTO) => void;
|
||||
handleAddToBoard: (boardId: string) => void;
|
||||
};
|
||||
|
||||
export const AddImageToBoardContext =
|
||||
createContext<AddImageToBoardContextValue>({
|
||||
isOpen: false,
|
||||
onClose: () => undefined,
|
||||
onClickAddToBoard: () => undefined,
|
||||
handleAddToBoard: () => undefined,
|
||||
});
|
||||
|
||||
type Props = PropsWithChildren;
|
||||
|
||||
export const AddImageToBoardContextProvider = (props: Props) => {
|
||||
const [imageToMove, setImageToMove] = useState<ImageDTO>();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const [addImageToBoard, result] = useAddImageToBoardMutation();
|
||||
|
||||
// Clean up after deleting or dismissing the modal
|
||||
const closeAndClearImageToDelete = useCallback(() => {
|
||||
setImageToMove(undefined);
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
|
||||
const onClickAddToBoard = useCallback(
|
||||
(image?: ImageDTO) => {
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
setImageToMove(image);
|
||||
onOpen();
|
||||
},
|
||||
[setImageToMove, onOpen]
|
||||
);
|
||||
|
||||
const handleAddToBoard = useCallback(
|
||||
(boardId: string) => {
|
||||
if (imageToMove) {
|
||||
addImageToBoard({
|
||||
board_id: boardId,
|
||||
image_name: imageToMove.image_name,
|
||||
});
|
||||
closeAndClearImageToDelete();
|
||||
}
|
||||
},
|
||||
[addImageToBoard, closeAndClearImageToDelete, imageToMove]
|
||||
);
|
||||
|
||||
return (
|
||||
<AddImageToBoardContext.Provider
|
||||
value={{
|
||||
isOpen,
|
||||
image: imageToMove,
|
||||
onClose: closeAndClearImageToDelete,
|
||||
onClickAddToBoard,
|
||||
handleAddToBoard,
|
||||
}}
|
||||
>
|
||||
{props.children}
|
||||
</AddImageToBoardContext.Provider>
|
||||
);
|
||||
};
|
@ -35,25 +35,23 @@ export const selectImageUsage = createSelector(
|
||||
(state: RootState, image_name?: string) => image_name,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, image_name) => {
|
||||
const isInitialImage = generation.initialImage?.image_name === image_name;
|
||||
const isInitialImage = generation.initialImage?.imageName === image_name;
|
||||
|
||||
const isCanvasImage = canvas.layerState.objects.some(
|
||||
(obj) => obj.kind === 'image' && obj.image.image_name === image_name
|
||||
(obj) => obj.kind === 'image' && obj.imageName === image_name
|
||||
);
|
||||
|
||||
const isNodesImage = nodes.nodes.some((node) => {
|
||||
return some(
|
||||
node.data.inputs,
|
||||
(input) =>
|
||||
input.type === 'image' && input.value?.image_name === image_name
|
||||
(input) => input.type === 'image' && input.value === image_name
|
||||
);
|
||||
});
|
||||
|
||||
const isControlNetImage = some(
|
||||
controlNet.controlNets,
|
||||
(c) =>
|
||||
c.controlImage?.image_name === image_name ||
|
||||
c.processedControlImage?.image_name === image_name
|
||||
c.controlImage === image_name || c.processedControlImage === image_name
|
||||
);
|
||||
|
||||
const imageUsage: ImageUsage = {
|
||||
|
@ -5,7 +5,6 @@ import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersist
|
||||
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
|
||||
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
|
||||
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
|
||||
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
|
||||
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
|
||||
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
|
||||
import { omit } from 'lodash-es';
|
||||
@ -18,7 +17,6 @@ const serializationDenylist: {
|
||||
gallery: galleryPersistDenylist,
|
||||
generation: generationPersistDenylist,
|
||||
lightbox: lightboxPersistDenylist,
|
||||
models: modelsPersistDenylist,
|
||||
nodes: nodesPersistDenylist,
|
||||
postprocessing: postprocessingPersistDenylist,
|
||||
system: systemPersistDenylist,
|
||||
|
@ -7,7 +7,6 @@ import { initialNodesState } from 'features/nodes/store/nodesSlice';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
|
||||
import { initialConfigState } from 'features/system/store/configSlice';
|
||||
import { initialModelsState } from 'features/system/store/modelSlice';
|
||||
import { initialSystemState } from 'features/system/store/systemSlice';
|
||||
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
|
||||
import { initialUIState } from 'features/ui/store/uiSlice';
|
||||
@ -21,7 +20,6 @@ const initialStates: {
|
||||
gallery: initialGalleryState,
|
||||
generation: initialGenerationState,
|
||||
lightbox: initialLightboxState,
|
||||
models: initialModelsState,
|
||||
nodes: initialNodesState,
|
||||
postprocessing: initialPostprocessingState,
|
||||
system: initialSystemState,
|
||||
|
@ -73,6 +73,15 @@ import { addImageCategoriesChangedListener } from './listeners/imageCategoriesCh
|
||||
import { addControlNetImageProcessedListener } from './listeners/controlNetImageProcessed';
|
||||
import { addControlNetAutoProcessListener } from './listeners/controlNetAutoProcess';
|
||||
import { addUpdateImageUrlsOnConnectListener } from './listeners/updateImageUrlsOnConnect';
|
||||
import {
|
||||
addImageAddedToBoardFulfilledListener,
|
||||
addImageAddedToBoardRejectedListener,
|
||||
} from './listeners/imageAddedToBoard';
|
||||
import { addBoardIdSelectedListener } from './listeners/boardIdSelected';
|
||||
import {
|
||||
addImageRemovedFromBoardFulfilledListener,
|
||||
addImageRemovedFromBoardRejectedListener,
|
||||
} from './listeners/imageRemovedFromBoard';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -92,6 +101,12 @@ export type AppListenerEffect = ListenerEffect<
|
||||
AppDispatch
|
||||
>;
|
||||
|
||||
/**
|
||||
* The RTK listener middleware is a lightweight alternative sagas/observables.
|
||||
*
|
||||
* Most side effect logic should live in a listener.
|
||||
*/
|
||||
|
||||
// Image uploaded
|
||||
addImageUploadedFulfilledListener();
|
||||
addImageUploadedRejectedListener();
|
||||
@ -183,3 +198,10 @@ addControlNetAutoProcessListener();
|
||||
|
||||
// Update image URLs on connect
|
||||
addUpdateImageUrlsOnConnectListener();
|
||||
|
||||
// Boards
|
||||
addImageAddedToBoardFulfilledListener();
|
||||
addImageAddedToBoardRejectedListener();
|
||||
addImageRemovedFromBoardFulfilledListener();
|
||||
addImageRemovedFromBoardRejectedListener();
|
||||
addBoardIdSelectedListener();
|
||||
|
@ -0,0 +1,99 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import { selectImagesAll } from 'features/gallery/store/imagesSlice';
|
||||
import { IMAGES_PER_PAGE, receivedPageOfImages } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addBoardIdSelectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: boardIdSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const boardId = action.payload;
|
||||
|
||||
// we need to check if we need to fetch more images
|
||||
|
||||
const state = getState();
|
||||
const allImages = selectImagesAll(state);
|
||||
|
||||
if (!boardId) {
|
||||
// a board was unselected
|
||||
dispatch(imageSelected(allImages[0]?.image_name));
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
|
||||
const filteredImages = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
// get the board from the cache
|
||||
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||
const board = boards?.find((b) => b.board_id === boardId);
|
||||
|
||||
if (!board) {
|
||||
// can't find the board in cache...
|
||||
dispatch(imageSelected(allImages[0]?.image_name));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(imageSelected(board.cover_image_name));
|
||||
|
||||
// if we haven't loaded one full page of images from this board, load more
|
||||
if (
|
||||
filteredImages.length < board.image_count &&
|
||||
filteredImages.length < IMAGES_PER_PAGE
|
||||
) {
|
||||
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addBoardIdSelected_changeSelectedImage_listener = () => {
|
||||
startAppListening({
|
||||
actionCreator: boardIdSelected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const boardId = action.payload;
|
||||
|
||||
const state = getState();
|
||||
|
||||
// we need to check if we need to fetch more images
|
||||
|
||||
if (!boardId) {
|
||||
// a board was unselected - we don't need to do anything
|
||||
return;
|
||||
}
|
||||
|
||||
const { categories } = state.images;
|
||||
|
||||
const filteredImages = selectImagesAll(state).filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = boardId ? i.board_id === boardId : true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
// get the board from the cache
|
||||
const { data: boards } = api.endpoints.listAllBoards.select()(state);
|
||||
const board = boards?.find((b) => b.board_id === boardId);
|
||||
if (!board) {
|
||||
// can't find the board in cache...
|
||||
return;
|
||||
}
|
||||
|
||||
// if we haven't loaded one full page of images from this board, load more
|
||||
if (
|
||||
filteredImages.length < board.image_count &&
|
||||
filteredImages.length < IMAGES_PER_PAGE
|
||||
) {
|
||||
dispatch(receivedPageOfImages({ categories, boardId }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -34,7 +34,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
[controlNet.processorNode.id]: {
|
||||
...controlNet.processorNode,
|
||||
is_intermediate: true,
|
||||
image: pick(controlNet.controlImage, ['image_name']),
|
||||
image: { image_name: controlNet.controlImage },
|
||||
},
|
||||
},
|
||||
};
|
||||
@ -81,7 +81,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
dispatch(
|
||||
controlNetProcessedImageChanged({
|
||||
controlNetId,
|
||||
processedControlImage,
|
||||
processedControlImage: processedControlImage.image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addImageAddedToBoardFulfilledListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.addImageToBoard.matchFulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Image added to board'
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageName: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageAddedToBoardRejectedListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.addImageToBoard.matchRejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Problem adding image to board'
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -12,12 +12,16 @@ export const addImageCategoriesChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: imageCategoriesChanged,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(
|
||||
getState()
|
||||
).length;
|
||||
const state = getState();
|
||||
const filteredImagesCount = selectFilteredImagesAsArray(state).length;
|
||||
|
||||
if (!filteredImagesCount) {
|
||||
dispatch(receivedPageOfImages());
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: action.payload,
|
||||
boardId: state.boards.selectedBoardId,
|
||||
})
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -6,15 +6,15 @@ import { clamp } from 'lodash-es';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageRemoved,
|
||||
selectImagesEntities,
|
||||
selectImagesIds,
|
||||
} from 'features/gallery/store/imagesSlice';
|
||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
const moduleLog = log.child({ namespace: 'image' });
|
||||
|
||||
/**
|
||||
* Called when the user requests an image deletion
|
||||
@ -22,7 +22,7 @@ const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
|
||||
export const addRequestedImageDeletionListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: requestedImageDeletion,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch, getState, condition }) => {
|
||||
const { image, imageUsage } = action.payload;
|
||||
|
||||
const { image_name } = image;
|
||||
@ -30,9 +30,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
const state = getState();
|
||||
const selectedImage = state.gallery.selectedImage;
|
||||
|
||||
if (selectedImage && selectedImage.image_name === image_name) {
|
||||
if (selectedImage === image_name) {
|
||||
const ids = selectImagesIds(state);
|
||||
const entities = selectImagesEntities(state);
|
||||
|
||||
const deletedImageIndex = ids.findIndex(
|
||||
(result) => result.toString() === image_name
|
||||
@ -48,10 +47,8 @@ export const addRequestedImageDeletionListener = () => {
|
||||
|
||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
||||
|
||||
const newSelectedImage = entities[newSelectedImageId];
|
||||
|
||||
if (newSelectedImageId) {
|
||||
dispatch(imageSelected(newSelectedImage));
|
||||
dispatch(imageSelected(newSelectedImageId as string));
|
||||
} else {
|
||||
dispatch(imageSelected());
|
||||
}
|
||||
@ -79,7 +76,21 @@ export const addRequestedImageDeletionListener = () => {
|
||||
dispatch(imageRemoved(image_name));
|
||||
|
||||
// Delete from server
|
||||
dispatch(imageDeleted({ imageName: image_name }));
|
||||
const { requestId } = dispatch(imageDeleted({ imageName: image_name }));
|
||||
|
||||
// Wait for successful deletion, then trigger boards to re-fetch
|
||||
const wasImageDeleted = await condition(
|
||||
(action): action is ReturnType<typeof imageDeleted.fulfilled> =>
|
||||
imageDeleted.fulfilled.match(action) &&
|
||||
action.meta.requestId === requestId,
|
||||
30000
|
||||
);
|
||||
|
||||
if (wasImageDeleted) {
|
||||
dispatch(
|
||||
api.util.invalidateTags([{ type: 'Board', id: image.board_id }])
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -0,0 +1,40 @@
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { startAppListening } from '..';
|
||||
import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'boards' });
|
||||
|
||||
export const addImageRemovedFromBoardFulfilledListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.removeImageFromBoard.matchFulfilled,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Image added to board'
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imageMetadataReceived({
|
||||
imageName: image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addImageRemovedFromBoardRejectedListener = () => {
|
||||
startAppListening({
|
||||
matcher: api.endpoints.removeImageFromBoard.matchRejected,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const { board_id, image_name } = action.meta.arg.originalArgs;
|
||||
|
||||
moduleLog.debug(
|
||||
{ data: { board_id, image_name } },
|
||||
'Problem adding image to board'
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -46,7 +46,12 @@ export const addImageUploadedFulfilledListener = () => {
|
||||
|
||||
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
|
||||
const { controlNetId } = postUploadAction;
|
||||
dispatch(controlNetImageChanged({ controlNetId, controlImage: image }));
|
||||
dispatch(
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: image.image_name,
|
||||
})
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,8 @@
|
||||
import { startAppListening } from '../..';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import { receivedModels } from 'services/thunks/model';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
|
||||
@ -15,16 +14,17 @@ export const addSocketConnectedEventListener = () => {
|
||||
|
||||
moduleLog.debug({ timestamp }, 'Connected');
|
||||
|
||||
const { models, nodes, config, images } = getState();
|
||||
const { nodes, config, images } = getState();
|
||||
|
||||
const { disabledTabs } = config;
|
||||
|
||||
if (!images.ids.length) {
|
||||
dispatch(receivedPageOfImages());
|
||||
}
|
||||
|
||||
if (!models.ids.length) {
|
||||
dispatch(receivedModels());
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories: ['general'],
|
||||
isIntermediate: false,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
if (!nodes.schema && !disabledTabs.includes('nodes')) {
|
||||
|
@ -9,6 +9,7 @@ import { imageMetadataReceived } from 'services/thunks/image';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'socketio' });
|
||||
const nodeDenylist = ['dataURL_image'];
|
||||
@ -24,7 +25,8 @@ export const addInvocationCompleteEventListener = () => {
|
||||
|
||||
const sessionId = action.payload.data.graph_execution_state_id;
|
||||
|
||||
const { cancelType, isCancelScheduled } = getState().system;
|
||||
const { cancelType, isCancelScheduled, boardIdToAddTo } =
|
||||
getState().system;
|
||||
|
||||
// Handle scheduled cancelation
|
||||
if (cancelType === 'scheduled' && isCancelScheduled) {
|
||||
@ -57,6 +59,15 @@ export const addInvocationCompleteEventListener = () => {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
if (boardIdToAddTo && !imageDTO.is_intermediate) {
|
||||
dispatch(
|
||||
api.endpoints.addImageToBoard.initiate({
|
||||
board_id: boardIdToAddTo,
|
||||
image_name,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(progressImageSet(null));
|
||||
}
|
||||
// pass along the socket event as an application action
|
||||
|
@ -22,15 +22,15 @@ const selectAllUsedImages = createSelector(
|
||||
selectImagesEntities,
|
||||
],
|
||||
(generation, canvas, nodes, controlNet, imageEntities) => {
|
||||
const allUsedImages: ImageDTO[] = [];
|
||||
const allUsedImages: string[] = [];
|
||||
|
||||
if (generation.initialImage) {
|
||||
allUsedImages.push(generation.initialImage);
|
||||
allUsedImages.push(generation.initialImage.imageName);
|
||||
}
|
||||
|
||||
canvas.layerState.objects.forEach((obj) => {
|
||||
if (obj.kind === 'image') {
|
||||
allUsedImages.push(obj.image);
|
||||
allUsedImages.push(obj.imageName);
|
||||
}
|
||||
});
|
||||
|
||||
@ -53,7 +53,7 @@ const selectAllUsedImages = createSelector(
|
||||
|
||||
forEach(imageEntities, (image) => {
|
||||
if (image) {
|
||||
allUsedImages.push(image);
|
||||
allUsedImages.push(image.image_name);
|
||||
}
|
||||
});
|
||||
|
||||
@ -80,7 +80,7 @@ export const addUpdateImageUrlsOnConnectListener = () => {
|
||||
`Fetching new image URLs for ${allUsedImages.length} images`
|
||||
);
|
||||
|
||||
allUsedImages.forEach(({ image_name }) => {
|
||||
allUsedImages.forEach((image_name) => {
|
||||
dispatch(
|
||||
imageUrlsReceived({
|
||||
imageName: image_name,
|
||||
|
@ -1,11 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { imageUpdated, imageUploaded } from 'services/thunks/image';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Graph } from 'services/api';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import {
|
||||
canvasSessionIdChanged,
|
||||
stagingAreaInitialized,
|
||||
@ -67,112 +66,106 @@ export const addUserInvokedCanvasListener = () => {
|
||||
|
||||
moduleLog.debug(`Generation mode: ${generationMode}`);
|
||||
|
||||
// Build the canvas graph
|
||||
const graphComponents = await buildCanvasGraphComponents(
|
||||
state,
|
||||
generationMode
|
||||
);
|
||||
// Temp placeholders for the init and mask images
|
||||
let canvasInitImage: ImageDTO | undefined;
|
||||
let canvasMaskImage: ImageDTO | undefined;
|
||||
|
||||
if (!graphComponents) {
|
||||
moduleLog.error('Problem building graph');
|
||||
return;
|
||||
}
|
||||
|
||||
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
|
||||
|
||||
// Assemble! Note that this graph *does not have the init or mask image set yet!*
|
||||
const nodes: Graph['nodes'] = {
|
||||
[rangeNode.id]: rangeNode,
|
||||
[iterateNode.id]: iterateNode,
|
||||
[baseNode.id]: baseNode,
|
||||
};
|
||||
|
||||
const graph = { nodes, edges };
|
||||
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
|
||||
moduleLog.debug({ data: graph }, 'Canvas graph built');
|
||||
|
||||
// If we are generating img2img or inpaint, we need to upload the init images
|
||||
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
|
||||
const baseFilename = `${uuidv4()}.png`;
|
||||
dispatch(
|
||||
// For img2img and inpaint/outpaint, we need to upload the init images
|
||||
if (['img2img', 'inpaint', 'outpaint'].includes(generationMode)) {
|
||||
// upload the image, saving the request id
|
||||
const { requestId: initImageUploadedRequestId } = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
|
||||
file: new File([baseBlob], 'canvasInitImage.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'general',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Wait for the image to be uploaded
|
||||
const [{ payload: baseImageDTO }] = await take(
|
||||
// Wait for the image to be uploaded, matching by request id
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === baseFilename
|
||||
action.meta.requestId === initImageUploadedRequestId
|
||||
);
|
||||
|
||||
// Update the base node with the image name and type
|
||||
baseNode.image = {
|
||||
image_name: baseImageDTO.image_name,
|
||||
};
|
||||
canvasInitImage = payload;
|
||||
}
|
||||
|
||||
// For inpaint, we also need to upload the mask layer
|
||||
if (baseNode.type === 'inpaint') {
|
||||
const maskFilename = `${uuidv4()}.png`;
|
||||
dispatch(
|
||||
// For inpaint/outpaint, we also need to upload the mask layer
|
||||
if (['inpaint', 'outpaint'].includes(generationMode)) {
|
||||
// upload the image, saving the request id
|
||||
const { requestId: maskImageUploadedRequestId } = dispatch(
|
||||
imageUploaded({
|
||||
formData: {
|
||||
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
|
||||
file: new File([maskBlob], 'canvasMaskImage.png', {
|
||||
type: 'image/png',
|
||||
}),
|
||||
},
|
||||
imageCategory: 'mask',
|
||||
isIntermediate: true,
|
||||
})
|
||||
);
|
||||
|
||||
// Wait for the mask to be uploaded
|
||||
const [{ payload: maskImageDTO }] = await take(
|
||||
// Wait for the image to be uploaded, matching by request id
|
||||
const [{ payload }] = await take(
|
||||
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
|
||||
imageUploaded.fulfilled.match(action) &&
|
||||
action.meta.arg.formData.file.name === maskFilename
|
||||
action.meta.requestId === maskImageUploadedRequestId
|
||||
);
|
||||
|
||||
// Update the base node with the image name and type
|
||||
baseNode.mask = {
|
||||
image_name: maskImageDTO.image_name,
|
||||
};
|
||||
canvasMaskImage = payload;
|
||||
}
|
||||
|
||||
// Create the session and wait for response
|
||||
dispatch(sessionCreated({ graph }));
|
||||
const [sessionCreatedAction] = await take(sessionCreated.fulfilled.match);
|
||||
const graph = buildCanvasGraph(
|
||||
state,
|
||||
generationMode,
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
|
||||
moduleLog.debug({ graph }, `Canvas graph built`);
|
||||
|
||||
// currently this action is just listened to for logging
|
||||
dispatch(canvasGraphBuilt(graph));
|
||||
|
||||
// Create the session, store the request id
|
||||
const { requestId: sessionCreatedRequestId } = dispatch(
|
||||
sessionCreated({ graph })
|
||||
);
|
||||
|
||||
// Take the session created action, matching by its request id
|
||||
const [sessionCreatedAction] = await take(
|
||||
(action): action is ReturnType<typeof sessionCreated.fulfilled> =>
|
||||
sessionCreated.fulfilled.match(action) &&
|
||||
action.meta.requestId === sessionCreatedRequestId
|
||||
);
|
||||
const sessionId = sessionCreatedAction.payload.id;
|
||||
|
||||
// Associate the init image with the session, now that we have the session ID
|
||||
if (
|
||||
(baseNode.type === 'img2img' || baseNode.type === 'inpaint') &&
|
||||
baseNode.image
|
||||
) {
|
||||
if (['img2img', 'inpaint'].includes(generationMode) && canvasInitImage) {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.image.image_name,
|
||||
imageName: canvasInitImage.image_name,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Associate the mask image with the session, now that we have the session ID
|
||||
if (baseNode.type === 'inpaint' && baseNode.mask) {
|
||||
if (['inpaint'].includes(generationMode) && canvasMaskImage) {
|
||||
dispatch(
|
||||
imageUpdated({
|
||||
imageName: baseNode.mask.image_name,
|
||||
imageName: canvasMaskImage.image_name,
|
||||
requestBody: { session_id: sessionId },
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Prep the canvas staging area if it is not yet initialized
|
||||
if (!state.canvas.layerState.stagingArea.boundingBox) {
|
||||
dispatch(
|
||||
stagingAreaInitialized({
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
@ -15,7 +15,7 @@ export const addUserInvokedImageToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildImageToImageGraph(state);
|
||||
const graph = buildLinearImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
moduleLog.debug({ data: graph }, 'Image to Image graph built');
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import { startAppListening } from '..';
|
||||
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
|
||||
import { sessionCreated } from 'services/thunks/session';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'invoke' });
|
||||
|
||||
@ -15,7 +15,7 @@ export const addUserInvokedTextToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const state = getState();
|
||||
|
||||
const graph = buildTextToImageGraph(state);
|
||||
const graph = buildLinearTextToImageGraph(state);
|
||||
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
|
||||
|
@ -5,40 +5,39 @@ import {
|
||||
configureStore,
|
||||
} from '@reduxjs/toolkit';
|
||||
|
||||
import { rememberReducer, rememberEnhancer } from 'redux-remember';
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { rememberEnhancer, rememberReducer } from 'redux-remember';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import imagesReducer from 'features/gallery/store/imagesSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
// import sessionReducer from 'features/system/store/sessionSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import hotkeysReducer from 'features/ui/store/hotkeysSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
|
||||
import { listenerMiddleware } from './middleware/listenerMiddleware';
|
||||
|
||||
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
|
||||
|
||||
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { serialize } from './enhancers/reduxRemember/serialize';
|
||||
import { unserialize } from './enhancers/reduxRemember/unserialize';
|
||||
import { LOCALSTORAGE_PREFIX } from './constants';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
const allReducers = {
|
||||
canvas: canvasReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
system: systemReducer,
|
||||
@ -47,7 +46,9 @@ const allReducers = {
|
||||
hotkeys: hotkeysReducer,
|
||||
images: imagesReducer,
|
||||
controlNet: controlNetReducer,
|
||||
boards: boardsReducer,
|
||||
// session: sessionReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
const rootReducer = combineReducers(allReducers);
|
||||
@ -59,12 +60,12 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'gallery',
|
||||
'generation',
|
||||
'lightbox',
|
||||
// 'models',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
'ui',
|
||||
'controlNet',
|
||||
// 'boards',
|
||||
// 'hotkeys',
|
||||
// 'config',
|
||||
];
|
||||
@ -84,6 +85,7 @@ export const store = configureStore({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
.prepend(listenerMiddleware.middleware),
|
||||
devTools: {
|
||||
|
@ -9,7 +9,7 @@ import {
|
||||
import { useDraggable, useDroppable } from '@dnd-kit/core';
|
||||
import { useCombinedRefs } from '@dnd-kit/utilities';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||
@ -53,7 +53,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
isDropDisabled = false,
|
||||
isDragDisabled = false,
|
||||
isUploadDisabled = false,
|
||||
fallback = <IAIImageFallback />,
|
||||
fallback = <IAIImageLoadingFallback />,
|
||||
payloadImage,
|
||||
minSize = 24,
|
||||
postUploadAction,
|
||||
|
@ -1,10 +1,20 @@
|
||||
import { Flex, FlexProps, Spinner, SpinnerProps } from '@chakra-ui/react';
|
||||
import {
|
||||
As,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Icon,
|
||||
IconProps,
|
||||
Spinner,
|
||||
SpinnerProps,
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactElement } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
type Props = FlexProps & {
|
||||
spinnerProps?: SpinnerProps;
|
||||
};
|
||||
|
||||
export const IAIImageFallback = (props: Props) => {
|
||||
export const IAIImageLoadingFallback = (props: Props) => {
|
||||
const { spinnerProps, ...rest } = props;
|
||||
const { sx, ...restFlexProps } = rest;
|
||||
return (
|
||||
@ -25,3 +35,35 @@ export const IAIImageFallback = (props: Props) => {
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type IAINoImageFallbackProps = {
|
||||
flexProps?: FlexProps;
|
||||
iconProps?: IconProps;
|
||||
as?: As;
|
||||
};
|
||||
|
||||
export const IAINoImageFallback = (props: IAINoImageFallbackProps) => {
|
||||
const { sx: flexSx, ...restFlexProps } = props.flexProps ?? { sx: {} };
|
||||
const { sx: iconSx, ...restIconProps } = props.iconProps ?? { sx: {} };
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
bg: 'base.900',
|
||||
opacity: 0.7,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
...flexSx,
|
||||
}}
|
||||
{...restFlexProps}
|
||||
>
|
||||
<Icon
|
||||
as={props.as ?? FaImage}
|
||||
sx={{ color: 'base.700', ...iconSx }}
|
||||
{...restIconProps}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -1,14 +1,21 @@
|
||||
import { Image } from 'react-konva';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { Image, Rect } from 'react-konva';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import useImage from 'use-image';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
|
||||
type IAICanvasImageProps = {
|
||||
url: string;
|
||||
x: number;
|
||||
y: number;
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||
const { url, x, y } = props;
|
||||
const [image] = useImage(url, 'anonymous');
|
||||
const { width, height, x, y, imageName } = props.canvasImage;
|
||||
const { data: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
|
||||
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
|
||||
|
||||
if (!imageDTO) {
|
||||
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
||||
}
|
||||
|
||||
return <Image x={x} y={y} image={image} listening={false} />;
|
||||
};
|
||||
|
||||
|
@ -39,14 +39,7 @@ const IAICanvasObjectRenderer = () => {
|
||||
<Group name="outpainting-objects" listening={false}>
|
||||
{objects.map((obj, i) => {
|
||||
if (isCanvasBaseImage(obj)) {
|
||||
return (
|
||||
<IAICanvasImage
|
||||
key={i}
|
||||
x={obj.x}
|
||||
y={obj.y}
|
||||
url={obj.image.image_url}
|
||||
/>
|
||||
);
|
||||
return <IAICanvasImage key={i} canvasImage={obj} />;
|
||||
} else if (isCanvasBaseLine(obj)) {
|
||||
const line = (
|
||||
<Line
|
||||
|
@ -59,11 +59,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
return (
|
||||
<Group {...rest}>
|
||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||
<IAICanvasImage
|
||||
url={currentStagingAreaImage.image.image_url}
|
||||
x={x}
|
||||
y={y}
|
||||
/>
|
||||
<IAICanvasImage canvasImage={currentStagingAreaImage} />
|
||||
)}
|
||||
{shouldShowStagingOutline && (
|
||||
<Group>
|
||||
|
@ -203,7 +203,7 @@ export const canvasSlice = createSlice({
|
||||
y: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
image: image,
|
||||
imageName: image.image_name,
|
||||
},
|
||||
],
|
||||
};
|
||||
@ -325,7 +325,7 @@ export const canvasSlice = createSlice({
|
||||
kind: 'image',
|
||||
layer: 'base',
|
||||
...state.layerState.stagingArea.boundingBox,
|
||||
image,
|
||||
imageName: image.image_name,
|
||||
});
|
||||
|
||||
state.layerState.stagingArea.selectedImageIndex =
|
||||
@ -865,25 +865,25 @@ export const canvasSlice = createSlice({
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
state.layerState.objects.forEach((object) => {
|
||||
if (object.kind === 'image') {
|
||||
if (object.image.image_name === image_name) {
|
||||
object.image.image_url = image_url;
|
||||
object.image.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
// state.layerState.objects.forEach((object) => {
|
||||
// if (object.kind === 'image') {
|
||||
// if (object.image.image_name === image_name) {
|
||||
// object.image.image_url = image_url;
|
||||
// object.image.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||
if (stagedImage.image.image_name === image_name) {
|
||||
stagedImage.image.image_url = image_url;
|
||||
stagedImage.image.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
});
|
||||
// state.layerState.stagingArea.images.forEach((stagedImage) => {
|
||||
// if (stagedImage.image.image_name === image_name) {
|
||||
// stagedImage.image.image_url = image_url;
|
||||
// stagedImage.image.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -38,7 +38,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: ImageDTO;
|
||||
imageName: string;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
|
@ -11,9 +11,11 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaUndo } from 'react-icons/fa';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const selector = createSelector(
|
||||
controlNetSelector,
|
||||
@ -31,24 +33,45 @@ type Props = {
|
||||
|
||||
const ControlNetImagePreview = (props: Props) => {
|
||||
const { imageSx } = props;
|
||||
const { controlNetId, controlImage, processedControlImage, processorType } =
|
||||
props.controlNet;
|
||||
const {
|
||||
controlNetId,
|
||||
controlImage: controlImageName,
|
||||
processedControlImage: processedControlImageName,
|
||||
processorType,
|
||||
} = props.controlNet;
|
||||
const dispatch = useAppDispatch();
|
||||
const { pendingControlImages } = useAppSelector(selector);
|
||||
|
||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||
|
||||
const {
|
||||
data: controlImage,
|
||||
isLoading: isLoadingControlImage,
|
||||
isError: isErrorControlImage,
|
||||
isSuccess: isSuccessControlImage,
|
||||
} = useGetImageDTOQuery(controlImageName ?? skipToken);
|
||||
|
||||
const {
|
||||
data: processedControlImage,
|
||||
isLoading: isLoadingProcessedControlImage,
|
||||
isError: isErrorProcessedControlImage,
|
||||
isSuccess: isSuccessProcessedControlImage,
|
||||
} = useGetImageDTOQuery(processedControlImageName ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (controlImage?.image_name === droppedImage.image_name) {
|
||||
if (controlImageName === droppedImage.image_name) {
|
||||
return;
|
||||
}
|
||||
setIsMouseOverImage(false);
|
||||
dispatch(
|
||||
controlNetImageChanged({ controlNetId, controlImage: droppedImage })
|
||||
controlNetImageChanged({
|
||||
controlNetId,
|
||||
controlImage: droppedImage.image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
[controlImage, controlNetId, dispatch]
|
||||
[controlImageName, controlNetId, dispatch]
|
||||
);
|
||||
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
@ -150,7 +173,7 @@ const ControlNetImagePreview = (props: Props) => {
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<IAIImageFallback />
|
||||
<IAIImageLoadingFallback />
|
||||
</Box>
|
||||
)}
|
||||
{controlImage && (
|
||||
|
@ -39,8 +39,8 @@ export type ControlNetConfig = {
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
controlImage: ImageDTO | null;
|
||||
processedControlImage: ImageDTO | null;
|
||||
controlImage: string | null;
|
||||
processedControlImage: string | null;
|
||||
processorType: ControlNetProcessorType;
|
||||
processorNode: RequiredControlNetProcessorNode;
|
||||
shouldAutoConfig: boolean;
|
||||
@ -80,7 +80,7 @@ export const controlNetSlice = createSlice({
|
||||
},
|
||||
controlNetAddedFromImage: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: ImageDTO }>
|
||||
action: PayloadAction<{ controlNetId: string; controlImage: string }>
|
||||
) => {
|
||||
const { controlNetId, controlImage } = action.payload;
|
||||
state.controlNets[controlNetId] = {
|
||||
@ -108,7 +108,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
controlImage: ImageDTO | null;
|
||||
controlImage: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, controlImage } = action.payload;
|
||||
@ -125,7 +125,7 @@ export const controlNetSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
processedControlImage: ImageDTO | null;
|
||||
processedControlImage: string | null;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, processedControlImage } = action.payload;
|
||||
@ -260,30 +260,30 @@ export const controlNetSlice = createSlice({
|
||||
// Preemptively remove the image from the gallery
|
||||
const { imageName } = action.meta.arg;
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage?.image_name === imageName) {
|
||||
if (c.controlImage === imageName) {
|
||||
c.controlImage = null;
|
||||
c.processedControlImage = null;
|
||||
}
|
||||
if (c.processedControlImage?.image_name === imageName) {
|
||||
if (c.processedControlImage === imageName) {
|
||||
c.processedControlImage = null;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
forEach(state.controlNets, (c) => {
|
||||
if (c.controlImage?.image_name === image_name) {
|
||||
c.controlImage.image_url = image_url;
|
||||
c.controlImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
if (c.processedControlImage?.image_name === image_name) {
|
||||
c.processedControlImage.image_url = image_url;
|
||||
c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
});
|
||||
// forEach(state.controlNets, (c) => {
|
||||
// if (c.controlImage?.image_name === image_name) {
|
||||
// c.controlImage.image_url = image_url;
|
||||
// c.controlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// if (c.processedControlImage?.image_name === image_name) {
|
||||
// c.processedControlImage.image_url = image_url;
|
||||
// c.processedControlImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
|
||||
builder.addCase(appSocketInvocationError, (state, action) => {
|
||||
state.pendingControlImages = [];
|
||||
|
@ -0,0 +1,27 @@
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { useCallback } from 'react';
|
||||
import { useCreateBoardMutation } from 'services/apiSlice';
|
||||
|
||||
const DEFAULT_BOARD_NAME = 'My Board';
|
||||
|
||||
const AddBoardButton = () => {
|
||||
const [createBoard, { isLoading }] = useCreateBoardMutation();
|
||||
|
||||
const handleCreateBoard = useCallback(() => {
|
||||
createBoard(DEFAULT_BOARD_NAME);
|
||||
}, [createBoard]);
|
||||
|
||||
return (
|
||||
<IAIButton
|
||||
isLoading={isLoading}
|
||||
aria-label="Add Board"
|
||||
onClick={handleCreateBoard}
|
||||
size="sm"
|
||||
sx={{ px: 4 }}
|
||||
>
|
||||
Add Board
|
||||
</IAIButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default AddBoardButton;
|
@ -0,0 +1,93 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { FaImages } from 'react-icons/fa';
|
||||
import { boardIdSelected } from '../../store/boardSlice';
|
||||
import { useDispatch } from 'react-redux';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
import { useCallback } from 'react';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
|
||||
const AllImagesBoard = ({ isSelected }: { isSelected: boolean }) => {
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const handleAllImagesBoardClick = () => {
|
||||
dispatch(boardIdSelected());
|
||||
};
|
||||
|
||||
const [removeImageFromBoard, { isLoading }] =
|
||||
useRemoveImageFromBoardMutation();
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (!droppedImage.board_id) {
|
||||
return;
|
||||
}
|
||||
removeImageFromBoard({
|
||||
board_id: droppedImage.board_id,
|
||||
image_name: droppedImage.image_name,
|
||||
});
|
||||
},
|
||||
[removeImageFromBoard]
|
||||
);
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
id: `board_droppable_all_images`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
onClick={handleAllImagesBoardClick}
|
||||
>
|
||||
<Flex
|
||||
ref={setNodeRef}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
}}
|
||||
>
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaImages} />
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
<Text
|
||||
sx={{
|
||||
color: isSelected ? 'base.50' : 'base.200',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
}}
|
||||
>
|
||||
All Images
|
||||
</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default AllImagesBoard;
|
@ -0,0 +1,134 @@
|
||||
import {
|
||||
Collapse,
|
||||
Flex,
|
||||
Grid,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
} from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import {
|
||||
boardsSelector,
|
||||
setBoardSearchText,
|
||||
} from 'features/gallery/store/boardSlice';
|
||||
import { memo, useState } from 'react';
|
||||
import HoverableBoard from './HoverableBoard';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import AddBoardButton from './AddBoardButton';
|
||||
import AllImagesBoard from './AllImagesBoard';
|
||||
import { CloseIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[boardsSelector],
|
||||
(boardsState) => {
|
||||
const { selectedBoardId, searchText } = boardsState;
|
||||
return { selectedBoardId, searchText };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
type Props = {
|
||||
isOpen: boolean;
|
||||
};
|
||||
|
||||
const BoardsList = (props: Props) => {
|
||||
const { isOpen } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { selectedBoardId, searchText } = useAppSelector(selector);
|
||||
|
||||
const { data: boards } = useListAllBoardsQuery();
|
||||
|
||||
const filteredBoards = searchText
|
||||
? boards?.filter((board) =>
|
||||
board.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||
)
|
||||
: boards;
|
||||
|
||||
const [searchMode, setSearchMode] = useState(false);
|
||||
|
||||
const handleBoardSearch = (searchTerm: string) => {
|
||||
setSearchMode(searchTerm.length > 0);
|
||||
dispatch(setBoardSearchText(searchTerm));
|
||||
};
|
||||
const clearBoardSearch = () => {
|
||||
setSearchMode(false);
|
||||
dispatch(setBoardSearchText(''));
|
||||
};
|
||||
|
||||
return (
|
||||
<Collapse in={isOpen} animateOpacity>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
gap: 2,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'base',
|
||||
p: 2,
|
||||
mt: 2,
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ gap: 2, alignItems: 'center' }}>
|
||||
<InputGroup>
|
||||
<Input
|
||||
placeholder="Search Boards..."
|
||||
value={searchText}
|
||||
onChange={(e) => {
|
||||
handleBoardSearch(e.target.value);
|
||||
}}
|
||||
/>
|
||||
{searchText && searchText.length && (
|
||||
<InputRightElement>
|
||||
<IconButton
|
||||
onClick={clearBoardSearch}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
aria-label="Clear Search"
|
||||
icon={<CloseIcon boxSize={3} />}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
<AddBoardButton />
|
||||
</Flex>
|
||||
<OverlayScrollbarsComponent
|
||||
defer
|
||||
style={{ height: '100%', width: '100%' }}
|
||||
options={{
|
||||
scrollbars: {
|
||||
visibility: 'auto',
|
||||
autoHide: 'move',
|
||||
autoHideDelay: 1300,
|
||||
theme: 'os-theme-dark',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Grid
|
||||
className="list-container"
|
||||
sx={{
|
||||
gap: 2,
|
||||
gridTemplateRows: '5.5rem 5.5rem',
|
||||
gridAutoFlow: 'column dense',
|
||||
gridAutoColumns: '4rem',
|
||||
}}
|
||||
>
|
||||
{!searchMode && <AllImagesBoard isSelected={!selectedBoardId} />}
|
||||
{filteredBoards &&
|
||||
filteredBoards.map((board) => (
|
||||
<HoverableBoard
|
||||
key={board.board_id}
|
||||
board={board}
|
||||
isSelected={selectedBoardId === board.board_id}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Collapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(BoardsList);
|
@ -0,0 +1,193 @@
|
||||
import {
|
||||
Badge,
|
||||
Box,
|
||||
Editable,
|
||||
EditableInput,
|
||||
EditablePreview,
|
||||
Flex,
|
||||
Image,
|
||||
MenuItem,
|
||||
MenuList,
|
||||
} from '@chakra-ui/react';
|
||||
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { FaFolder, FaTrash } from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import { BoardDTO, ImageDTO } from 'services/api';
|
||||
import { IAINoImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { boardIdSelected } from 'features/gallery/store/boardSlice';
|
||||
import {
|
||||
useAddImageToBoardMutation,
|
||||
useDeleteBoardMutation,
|
||||
useGetImageDTOQuery,
|
||||
useUpdateBoardMutation,
|
||||
} from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { useDroppable } from '@dnd-kit/core';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import IAIDropOverlay from 'common/components/IAIDropOverlay';
|
||||
import { SelectedItemOverlay } from '../SelectedItemOverlay';
|
||||
|
||||
interface HoverableBoardProps {
|
||||
board: BoardDTO;
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const HoverableBoard = memo(({ board, isSelected }: HoverableBoardProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data: coverImage } = useGetImageDTOQuery(
|
||||
board.cover_image_name ?? skipToken
|
||||
);
|
||||
|
||||
const { board_name, board_id } = board;
|
||||
|
||||
const handleSelectBoard = useCallback(() => {
|
||||
dispatch(boardIdSelected(board_id));
|
||||
}, [board_id, dispatch]);
|
||||
|
||||
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||
useUpdateBoardMutation();
|
||||
|
||||
const [deleteBoard, { isLoading: isDeleteBoardLoading }] =
|
||||
useDeleteBoardMutation();
|
||||
|
||||
const [addImageToBoard, { isLoading: isAddImageToBoardLoading }] =
|
||||
useAddImageToBoardMutation();
|
||||
|
||||
const handleUpdateBoardName = (newBoardName: string) => {
|
||||
updateBoard({ board_id, changes: { board_name: newBoardName } });
|
||||
};
|
||||
|
||||
const handleDeleteBoard = useCallback(() => {
|
||||
deleteBoard(board_id);
|
||||
}, [board_id, deleteBoard]);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (droppedImage.board_id === board_id) {
|
||||
return;
|
||||
}
|
||||
addImageToBoard({ board_id, image_name: droppedImage.image_name });
|
||||
},
|
||||
[addImageToBoard, board_id]
|
||||
);
|
||||
|
||||
const {
|
||||
isOver,
|
||||
setNodeRef,
|
||||
active: isDropActive,
|
||||
} = useDroppable({
|
||||
id: `board_droppable_${board_id}`,
|
||||
data: {
|
||||
handleDrop,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Box sx={{ touchAction: 'none' }}>
|
||||
<ContextMenu<HTMLDivElement>
|
||||
menuProps={{ size: 'sm', isLazy: true }}
|
||||
renderMenu={() => (
|
||||
<MenuList sx={{ visibility: 'visible !important' }}>
|
||||
<MenuItem
|
||||
sx={{ color: 'error.300' }}
|
||||
icon={<FaTrash />}
|
||||
onClickCapture={handleDeleteBoard}
|
||||
>
|
||||
Delete Board
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
>
|
||||
{(ref) => (
|
||||
<Flex
|
||||
key={board_id}
|
||||
userSelect="none"
|
||||
ref={ref}
|
||||
sx={{
|
||||
flexDir: 'column',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
cursor: 'pointer',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
}}
|
||||
>
|
||||
<Flex
|
||||
ref={setNodeRef}
|
||||
onClick={handleSelectBoard}
|
||||
sx={{
|
||||
position: 'relative',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
aspectRatio: '1/1',
|
||||
overflow: 'hidden',
|
||||
}}
|
||||
>
|
||||
{board.cover_image_name && coverImage?.image_url && (
|
||||
<Image src={coverImage?.image_url} draggable={false} />
|
||||
)}
|
||||
{!(board.cover_image_name && coverImage?.image_url) && (
|
||||
<IAINoImageFallback iconProps={{ boxSize: 8 }} as={FaFolder} />
|
||||
)}
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
insetInlineEnd: 0,
|
||||
top: 0,
|
||||
p: 1,
|
||||
}}
|
||||
>
|
||||
<Badge variant="solid">{board.image_count}</Badge>
|
||||
</Flex>
|
||||
<AnimatePresence>
|
||||
{isSelected && <SelectedItemOverlay />}
|
||||
</AnimatePresence>
|
||||
<AnimatePresence>
|
||||
{isDropActive && <IAIDropOverlay isOver={isOver} />}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
|
||||
<Box sx={{ width: 'full' }}>
|
||||
<Editable
|
||||
defaultValue={board_name}
|
||||
submitOnBlur={false}
|
||||
onSubmit={(nextValue) => {
|
||||
handleUpdateBoardName(nextValue);
|
||||
}}
|
||||
>
|
||||
<EditablePreview
|
||||
sx={{
|
||||
color: isSelected ? 'base.50' : 'base.200',
|
||||
fontWeight: isSelected ? 600 : undefined,
|
||||
fontSize: 'xs',
|
||||
textAlign: 'center',
|
||||
p: 0,
|
||||
}}
|
||||
noOfLines={1}
|
||||
/>
|
||||
<EditableInput
|
||||
sx={{
|
||||
color: 'base.50',
|
||||
fontSize: 'xs',
|
||||
borderColor: 'base.500',
|
||||
p: 0,
|
||||
outline: 0,
|
||||
}}
|
||||
/>
|
||||
</Editable>
|
||||
</Box>
|
||||
</Flex>
|
||||
)}
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
HoverableBoard.displayName = 'HoverableBoard';
|
||||
|
||||
export default HoverableBoard;
|
@ -0,0 +1,93 @@
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogBody,
|
||||
AlertDialogContent,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogOverlay,
|
||||
Box,
|
||||
Flex,
|
||||
Spinner,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
|
||||
import { memo, useContext, useRef, useState } from 'react';
|
||||
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const UpdateImageBoardModal = () => {
|
||||
// const boards = useSelector(selectBoardsAll);
|
||||
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||
const { isOpen, onClose, handleAddToBoard, image } = useContext(
|
||||
AddImageToBoardContext
|
||||
);
|
||||
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||
|
||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||
|
||||
const currentBoard = boards?.find(
|
||||
(board) => board.board_id === image?.board_id
|
||||
);
|
||||
|
||||
return (
|
||||
<AlertDialog
|
||||
isOpen={isOpen}
|
||||
leastDestructiveRef={cancelRef}
|
||||
onClose={onClose}
|
||||
isCentered
|
||||
>
|
||||
<AlertDialogOverlay>
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Box>
|
||||
<Flex direction="column" gap={3}>
|
||||
{currentBoard && (
|
||||
<Text>
|
||||
Moving this image from{' '}
|
||||
<strong>{currentBoard.board_name}</strong> to
|
||||
</Text>
|
||||
)}
|
||||
{isFetching ? (
|
||||
<Spinner />
|
||||
) : (
|
||||
<IAIMantineSelect
|
||||
placeholder="Select Board"
|
||||
onChange={(v) => setSelectedBoard(v)}
|
||||
value={selectedBoard}
|
||||
data={(boards ?? []).map((board) => ({
|
||||
label: board.board_name,
|
||||
value: board.board_id,
|
||||
}))}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Box>
|
||||
</AlertDialogBody>
|
||||
<AlertDialogFooter>
|
||||
<IAIButton onClick={onClose}>Cancel</IAIButton>
|
||||
<IAIButton
|
||||
isDisabled={!selectedBoard}
|
||||
colorScheme="accent"
|
||||
onClick={() => {
|
||||
if (selectedBoard) {
|
||||
handleAddToBoard(selectedBoard);
|
||||
}
|
||||
}}
|
||||
ml={3}
|
||||
>
|
||||
{currentBoard ? 'Move' : 'Add'}
|
||||
</IAIButton>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialogOverlay>
|
||||
</AlertDialog>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(UpdateImageBoardModal);
|
@ -51,9 +51,12 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { DeleteImageButton } from './DeleteImageModal';
|
||||
import { selectImagesById } from '../store/imagesSlice';
|
||||
import { RootState } from 'app/store/store';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state,
|
||||
systemSelector,
|
||||
gallerySelector,
|
||||
postprocessingSelector,
|
||||
@ -61,7 +64,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
lightboxSelector,
|
||||
activeTabNameSelector,
|
||||
],
|
||||
(system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||
(state, system, gallery, postprocessing, ui, lightbox, activeTabName) => {
|
||||
const {
|
||||
isProcessing,
|
||||
isConnected,
|
||||
@ -81,6 +84,8 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowProgressInViewer,
|
||||
} = ui;
|
||||
|
||||
const imageDTO = selectImagesById(state, gallery.selectedImage ?? '');
|
||||
|
||||
const { selectedImage } = gallery;
|
||||
|
||||
return {
|
||||
@ -97,10 +102,10 @@ const currentImageButtonsSelector = createSelector(
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
shouldHidePreview,
|
||||
image: selectedImage,
|
||||
seed: selectedImage?.metadata?.seed,
|
||||
prompt: selectedImage?.metadata?.positive_conditioning,
|
||||
negativePrompt: selectedImage?.metadata?.negative_conditioning,
|
||||
image: imageDTO,
|
||||
seed: imageDTO?.metadata?.seed,
|
||||
prompt: imageDTO?.metadata?.positive_conditioning,
|
||||
negativePrompt: imageDTO?.metadata?.negative_conditioning,
|
||||
shouldShowProgressInViewer,
|
||||
};
|
||||
},
|
||||
|
@ -9,12 +9,12 @@ import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { configSelector } from '../../system/store/configSelectors';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { imageSelected } from '../store/gallerySlice';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { IAIImageFallback } from 'common/components/IAIImageFallback';
|
||||
import { IAIImageLoadingFallback } from 'common/components/IAIImageFallback';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[uiSelector, gallerySelector, systemSelector],
|
||||
@ -29,7 +29,7 @@ export const imagesSelector = createSelector(
|
||||
return {
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
image: selectedImage,
|
||||
selectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
@ -45,11 +45,23 @@ export const imagesSelector = createSelector(
|
||||
const CurrentImagePreview = () => {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
image,
|
||||
selectedImage,
|
||||
progressImage,
|
||||
shouldShowProgressInViewer,
|
||||
shouldAntialiasProgressImage,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
// const image = useAppSelector((state: RootState) =>
|
||||
// selectImagesById(state, selectedImage ?? '')
|
||||
// );
|
||||
|
||||
const {
|
||||
data: image,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(selectedImage ?? skipToken);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleDrop = useCallback(
|
||||
@ -57,7 +69,7 @@ const CurrentImagePreview = () => {
|
||||
if (droppedImage.image_name === image?.image_name) {
|
||||
return;
|
||||
}
|
||||
dispatch(imageSelected(droppedImage));
|
||||
dispatch(imageSelected(droppedImage.image_name));
|
||||
},
|
||||
[dispatch, image?.image_name]
|
||||
);
|
||||
@ -98,14 +110,14 @@ const CurrentImagePreview = () => {
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={image}
|
||||
image={selectedImage && image ? image : undefined}
|
||||
onDrop={handleDrop}
|
||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
||||
fallback={<IAIImageLoadingFallback sx={{ bg: 'none' }} />}
|
||||
isUploadDisabled={true}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
{shouldShowImageDetails && image && (
|
||||
{shouldShowImageDetails && image && selectedImage && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -119,7 +131,7 @@ const CurrentImagePreview = () => {
|
||||
<ImageMetadataViewer image={image} />
|
||||
</Box>
|
||||
)}
|
||||
{!shouldShowImageDetails && image && (
|
||||
{!shouldShowImageDetails && image && selectedImage && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
|
@ -2,7 +2,14 @@ import { Box, Flex, Icon, Image, MenuItem, MenuList } from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useContext, useState } from 'react';
|
||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||
import {
|
||||
FaCheck,
|
||||
FaExpand,
|
||||
FaFolder,
|
||||
FaImage,
|
||||
FaShare,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||
import {
|
||||
resizeAndScaleCanvas,
|
||||
@ -27,6 +34,8 @@ import { useAppToaster } from 'app/components/Toaster';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { useDraggable } from '@dnd-kit/core';
|
||||
import { DeleteImageContext } from 'app/contexts/DeleteImageContext';
|
||||
import { AddImageToBoardContext } from '../../../app/contexts/AddImageToBoardContext';
|
||||
import { useRemoveImageFromBoardMutation } from 'services/apiSlice';
|
||||
|
||||
export const selector = createSelector(
|
||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||
@ -62,17 +71,10 @@ interface HoverableImageProps {
|
||||
isSelected: boolean;
|
||||
}
|
||||
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) =>
|
||||
prev.image.image_name === next.image.image_name &&
|
||||
prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
*/
|
||||
const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const HoverableImage = (props: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
activeTabName,
|
||||
@ -93,6 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
|
||||
|
||||
const { onDelete } = useContext(DeleteImageContext);
|
||||
const { onClickAddToBoard } = useContext(AddImageToBoardContext);
|
||||
const handleDelete = useCallback(() => {
|
||||
onDelete(image);
|
||||
}, [image, onDelete]);
|
||||
@ -106,11 +109,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
},
|
||||
});
|
||||
|
||||
const [removeFromBoard] = useRemoveImageFromBoardMutation();
|
||||
|
||||
const handleMouseOver = () => setIsHovered(true);
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleSelectImage = useCallback(() => {
|
||||
dispatch(imageSelected(image));
|
||||
dispatch(imageSelected(image.image_name));
|
||||
}, [image, dispatch]);
|
||||
|
||||
// Recall parameters handlers
|
||||
@ -168,6 +173,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
const handleAddToBoard = useCallback(() => {
|
||||
onClickAddToBoard(image);
|
||||
}, [image, onClickAddToBoard]);
|
||||
|
||||
const handleRemoveFromBoard = useCallback(() => {
|
||||
if (!image.board_id) {
|
||||
return;
|
||||
}
|
||||
removeFromBoard({ board_id: image.board_id, image_name: image.image_name });
|
||||
}, [image.board_id, image.image_name, removeFromBoard]);
|
||||
|
||||
const handleOpenInNewTab = () => {
|
||||
window.open(image.image_url, '_blank');
|
||||
};
|
||||
@ -244,6 +260,17 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem icon={<FaFolder />} onClickCapture={handleAddToBoard}>
|
||||
{image.board_id ? 'Change Board' : 'Add to Board'}
|
||||
</MenuItem>
|
||||
{image.board_id && (
|
||||
<MenuItem
|
||||
icon={<FaFolder />}
|
||||
onClickCapture={handleRemoveFromBoard}
|
||||
>
|
||||
Remove from Board
|
||||
</MenuItem>
|
||||
)}
|
||||
<MenuItem
|
||||
sx={{ color: 'error.300' }}
|
||||
icon={<FaTrash />}
|
||||
@ -339,8 +366,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</ContextMenu>
|
||||
</Box>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
};
|
||||
|
||||
HoverableImage.displayName = 'HoverableImage';
|
||||
|
||||
export default HoverableImage;
|
||||
export default memo(HoverableImage);
|
||||
|
@ -1,12 +1,15 @@
|
||||
import {
|
||||
Box,
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Grid,
|
||||
Icon,
|
||||
Text,
|
||||
VStack,
|
||||
forwardRef,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@ -20,6 +23,7 @@ import {
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setGalleryView,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { togglePinGalleryPanel } from 'features/ui/store/uiSlice';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
@ -53,41 +57,51 @@ import {
|
||||
selectImagesAll,
|
||||
} from '../store/imagesSlice';
|
||||
import { receivedPageOfImages } from 'services/thunks/image';
|
||||
import BoardsList from './Boards/BoardsList';
|
||||
import { boardsSelector } from '../store/boardSlice';
|
||||
import { ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import { useListAllBoardsQuery } from 'services/apiSlice';
|
||||
|
||||
const categorySelector = createSelector(
|
||||
const itemSelector = createSelector(
|
||||
[(state: RootState) => state],
|
||||
(state) => {
|
||||
const { images } = state;
|
||||
const { categories } = images;
|
||||
const { categories, total: allImagesTotal, isLoading } = state.images;
|
||||
const { selectedBoardId } = state.boards;
|
||||
|
||||
const allImages = selectImagesAll(state);
|
||||
const filteredImages = allImages.filter((i) =>
|
||||
categories.includes(i.image_category)
|
||||
);
|
||||
|
||||
const images = allImages.filter((i) => {
|
||||
const isInCategory = categories.includes(i.image_category);
|
||||
const isInSelectedBoard = selectedBoardId
|
||||
? i.board_id === selectedBoardId
|
||||
: true;
|
||||
return isInCategory && isInSelectedBoard;
|
||||
});
|
||||
|
||||
return {
|
||||
images: filteredImages,
|
||||
isLoading: images.isLoading,
|
||||
areMoreImagesAvailable: filteredImages.length < images.total,
|
||||
categories: images.categories,
|
||||
images,
|
||||
allImagesTotal,
|
||||
isLoading,
|
||||
categories,
|
||||
selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const mainSelector = createSelector(
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery, ui) => {
|
||||
[gallerySelector, uiSelector, boardsSelector],
|
||||
(gallery, ui, boards) => {
|
||||
const {
|
||||
galleryImageMinimumWidth,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
} = gallery;
|
||||
|
||||
const { shouldPinGallery } = ui;
|
||||
|
||||
return {
|
||||
shouldPinGallery,
|
||||
galleryImageMinimumWidth,
|
||||
@ -95,6 +109,8 @@ const mainSelector = createSelector(
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
selectedBoardId: boards.selectedBoardId,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
@ -126,21 +142,44 @@ const ImageGalleryContent = () => {
|
||||
shouldAutoSwitchToNewImages,
|
||||
shouldUseSingleGalleryColumn,
|
||||
selectedImage,
|
||||
galleryView,
|
||||
} = useAppSelector(mainSelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading, categories } =
|
||||
useAppSelector(categorySelector);
|
||||
const { images, isLoading, allImagesTotal, categories, selectedBoardId } =
|
||||
useAppSelector(itemSelector);
|
||||
|
||||
const { selectedBoard } = useListAllBoardsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
selectedBoard: data?.find((b) => b.board_id === selectedBoardId),
|
||||
}),
|
||||
});
|
||||
|
||||
const filteredImagesTotal = useMemo(
|
||||
() => selectedBoard?.image_count ?? allImagesTotal,
|
||||
[allImagesTotal, selectedBoard?.image_count]
|
||||
);
|
||||
|
||||
const areMoreAvailable = useMemo(() => {
|
||||
return images.length < filteredImagesTotal;
|
||||
}, [images.length, filteredImagesTotal]);
|
||||
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(receivedPageOfImages());
|
||||
}, [dispatch]);
|
||||
dispatch(
|
||||
receivedPageOfImages({
|
||||
categories,
|
||||
boardId: selectedBoardId,
|
||||
})
|
||||
);
|
||||
}, [categories, dispatch, selectedBoardId]);
|
||||
|
||||
const handleEndReached = useMemo(() => {
|
||||
if (areMoreImagesAvailable && !isLoading) {
|
||||
if (areMoreAvailable && !isLoading) {
|
||||
return handleLoadMoreImages;
|
||||
}
|
||||
return undefined;
|
||||
}, [areMoreImagesAvailable, handleLoadMoreImages, isLoading]);
|
||||
}, [areMoreAvailable, handleLoadMoreImages, isLoading]);
|
||||
|
||||
const { isOpen: isBoardListOpen, onToggle } = useDisclosure();
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
dispatch(setGalleryImageMinimumWidth(v));
|
||||
@ -172,33 +211,38 @@ const ImageGalleryContent = () => {
|
||||
|
||||
const handleClickImagesCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(IMAGE_CATEGORIES));
|
||||
dispatch(setGalleryView('images'));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleClickAssetsCategory = useCallback(() => {
|
||||
dispatch(imageCategoriesChanged(ASSETS_CATEGORIES));
|
||||
dispatch(setGalleryView('assets'));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
<VStack
|
||||
sx={{
|
||||
gap: 2,
|
||||
flexDirection: 'column',
|
||||
h: 'full',
|
||||
w: 'full',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
>
|
||||
<Box sx={{ w: 'full' }}>
|
||||
<Flex
|
||||
ref={resizeObserverRef}
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
gap: 2,
|
||||
}}
|
||||
>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIIconButton
|
||||
tooltip={t('gallery.images')}
|
||||
aria-label={t('gallery.images')}
|
||||
onClick={handleClickImagesCategory}
|
||||
isChecked={categories === IMAGE_CATEGORIES}
|
||||
isChecked={galleryView === 'images'}
|
||||
size="sm"
|
||||
icon={<FaImage />}
|
||||
/>
|
||||
@ -206,12 +250,40 @@ const ImageGalleryContent = () => {
|
||||
tooltip={t('gallery.assets')}
|
||||
aria-label={t('gallery.assets')}
|
||||
onClick={handleClickAssetsCategory}
|
||||
isChecked={categories === ASSETS_CATEGORIES}
|
||||
isChecked={galleryView === 'assets'}
|
||||
size="sm"
|
||||
icon={<FaServer />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Flex gap={2}>
|
||||
<Flex
|
||||
as={Button}
|
||||
onClick={onToggle}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
sx={{
|
||||
w: 'full',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
px: 2,
|
||||
_hover: {
|
||||
bg: 'base.800',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Text
|
||||
noOfLines={1}
|
||||
sx={{ w: 'full', color: 'base.200', fontWeight: 600 }}
|
||||
>
|
||||
{selectedBoard ? selectedBoard.board_name : 'All Images'}
|
||||
</Text>
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
transform: isBoardListOpen ? 'rotate(0deg)' : 'rotate(180deg)',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
<IAIPopover
|
||||
triggerComponent={
|
||||
<IAIIconButton
|
||||
@ -269,9 +341,12 @@ const ImageGalleryContent = () => {
|
||||
icon={shouldPinGallery ? <BsPinAngleFill /> : <BsPinAngle />}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex direction="column" gap={2} h="full">
|
||||
{images.length || areMoreImagesAvailable ? (
|
||||
<Box>
|
||||
<BoardsList isOpen={isBoardListOpen} />
|
||||
</Box>
|
||||
</Box>
|
||||
<Flex direction="column" gap={2} h="full" w="full">
|
||||
{images.length || areMoreAvailable ? (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%">
|
||||
{shouldUseSingleGalleryColumn ? (
|
||||
@ -280,14 +355,12 @@ const ImageGalleryContent = () => {
|
||||
data={images}
|
||||
endReached={handleEndReached}
|
||||
scrollerRef={(ref) => setScrollerRef(ref)}
|
||||
itemContent={(index, image) => (
|
||||
itemContent={(index, item) => (
|
||||
<Flex sx={{ pb: 2 }}>
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
@ -302,13 +375,11 @@ const ImageGalleryContent = () => {
|
||||
List: ListContainer,
|
||||
}}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={(index, image) => (
|
||||
itemContent={(index, item) => (
|
||||
<HoverableImage
|
||||
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||
image={image}
|
||||
isSelected={
|
||||
selectedImage?.image_name === image?.image_name
|
||||
}
|
||||
key={`${item.image_name}-${item.thumbnail_url}`}
|
||||
image={item}
|
||||
isSelected={selectedImage === item?.image_name}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
@ -316,12 +387,12 @@ const ImageGalleryContent = () => {
|
||||
</Box>
|
||||
<IAIButton
|
||||
onClick={handleLoadMoreImages}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isDisabled={!areMoreAvailable}
|
||||
isLoading={isLoading}
|
||||
loadingText="Loading"
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
{areMoreAvailable
|
||||
? t('gallery.loadMore')
|
||||
: t('gallery.allImagesLoaded')}
|
||||
</IAIButton>
|
||||
@ -350,7 +421,7 @@ const ImageGalleryContent = () => {
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -93,19 +93,11 @@ type ImageMetadataViewerProps = {
|
||||
image: ImageDTO;
|
||||
};
|
||||
|
||||
// TODO: I don't know if this is needed.
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.image_name === next.image.image_name;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
/**
|
||||
* Image metadata viewer overlays currently selected image and provides
|
||||
* access to any of its metadata for use in processing.
|
||||
*/
|
||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const {
|
||||
recallBothPrompts,
|
||||
@ -333,8 +325,6 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
};
|
||||
|
||||
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
||||
|
||||
export default ImageMetadataViewer;
|
||||
export default memo(ImageMetadataViewer);
|
||||
|
@ -42,7 +42,7 @@ export const nextPrevImageButtonsSelector = createSelector(
|
||||
}
|
||||
|
||||
const currentImageIndex = filteredImageIds.findIndex(
|
||||
(i) => i === selectedImage.image_name
|
||||
(i) => i === selectedImage
|
||||
);
|
||||
|
||||
const nextImageIndex = clamp(
|
||||
@ -71,6 +71,8 @@ export const nextPrevImageButtonsSelector = createSelector(
|
||||
!isNaN(currentImageIndex) && currentImageIndex === imagesLength - 1,
|
||||
nextImage,
|
||||
prevImage,
|
||||
nextImageId,
|
||||
prevImageId,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -84,7 +86,7 @@ const NextPrevImageButtons = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { isOnFirstImage, isOnLastImage, nextImage, prevImage } =
|
||||
const { isOnFirstImage, isOnLastImage, nextImageId, prevImageId } =
|
||||
useAppSelector(nextPrevImageButtonsSelector);
|
||||
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] =
|
||||
@ -99,19 +101,19 @@ const NextPrevImageButtons = () => {
|
||||
}, []);
|
||||
|
||||
const handlePrevImage = useCallback(() => {
|
||||
dispatch(imageSelected(prevImage));
|
||||
}, [dispatch, prevImage]);
|
||||
dispatch(imageSelected(prevImageId));
|
||||
}, [dispatch, prevImageId]);
|
||||
|
||||
const handleNextImage = useCallback(() => {
|
||||
dispatch(imageSelected(nextImage));
|
||||
}, [dispatch, nextImage]);
|
||||
dispatch(imageSelected(nextImageId));
|
||||
}, [dispatch, nextImageId]);
|
||||
|
||||
useHotkeys(
|
||||
'left',
|
||||
() => {
|
||||
handlePrevImage();
|
||||
},
|
||||
[prevImage]
|
||||
[prevImageId]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@ -119,7 +121,7 @@ const NextPrevImageButtons = () => {
|
||||
() => {
|
||||
handleNextImage();
|
||||
},
|
||||
[nextImage]
|
||||
[nextImageId]
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,26 @@
|
||||
import { motion } from 'framer-motion';
|
||||
|
||||
export const SelectedItemOverlay = () => (
|
||||
<motion.div
|
||||
initial={{
|
||||
opacity: 0,
|
||||
}}
|
||||
animate={{
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
exit={{
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
}}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineStart: 0,
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
boxShadow: 'inset 0px 0px 0px 2px var(--invokeai-colors-accent-300)',
|
||||
borderRadius: 'var(--invokeai-radii-base)',
|
||||
}}
|
||||
/>
|
||||
);
|
@ -0,0 +1,23 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { selectBoardsAll } from './boardSlice';
|
||||
|
||||
export const boardSelector = (state: RootState) => state.boards.entities;
|
||||
|
||||
export const searchBoardsSelector = createSelector(
|
||||
(state: RootState) => state,
|
||||
(state) => {
|
||||
const {
|
||||
boards: { searchText },
|
||||
} = state;
|
||||
|
||||
if (!searchText) {
|
||||
// If no search text provided, return all entities
|
||||
return selectBoardsAll(state);
|
||||
}
|
||||
|
||||
return selectBoardsAll(state).filter((i) =>
|
||||
i.board_name.toLowerCase().includes(searchText.toLowerCase())
|
||||
);
|
||||
}
|
||||
);
|
@ -0,0 +1,47 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { api } from 'services/apiSlice';
|
||||
|
||||
type BoardsState = {
|
||||
searchText: string;
|
||||
selectedBoardId?: string;
|
||||
updateBoardModalOpen: boolean;
|
||||
};
|
||||
|
||||
export const initialBoardsState: BoardsState = {
|
||||
updateBoardModalOpen: false,
|
||||
searchText: '',
|
||||
};
|
||||
|
||||
const boardsSlice = createSlice({
|
||||
name: 'boards',
|
||||
initialState: initialBoardsState,
|
||||
reducers: {
|
||||
boardIdSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.selectedBoardId = action.payload;
|
||||
},
|
||||
setBoardSearchText: (state, action: PayloadAction<string>) => {
|
||||
state.searchText = action.payload;
|
||||
},
|
||||
setUpdateBoardModalOpen: (state, action: PayloadAction<boolean>) => {
|
||||
state.updateBoardModalOpen = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(
|
||||
api.endpoints.deleteBoard.matchFulfilled,
|
||||
(state, action) => {
|
||||
if (action.meta.arg.originalArgs === state.selectedBoardId) {
|
||||
state.selectedBoardId = undefined;
|
||||
}
|
||||
}
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
export const { boardIdSelected, setBoardSearchText, setUpdateBoardModalOpen } =
|
||||
boardsSlice.actions;
|
||||
|
||||
export const boardsSelector = (state: RootState) => state.boards;
|
||||
|
||||
export default boardsSlice.reducer;
|
@ -1,17 +1,16 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { imageUpserted } from './imagesSlice';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
export interface GalleryState {
|
||||
selectedImage?: ImageDTO;
|
||||
selectedImage?: string;
|
||||
galleryImageMinimumWidth: number;
|
||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||
shouldAutoSwitchToNewImages: boolean;
|
||||
shouldUseSingleGalleryColumn: boolean;
|
||||
galleryView: 'images' | 'assets' | 'boards';
|
||||
}
|
||||
|
||||
export const initialGalleryState: GalleryState = {
|
||||
@ -19,13 +18,14 @@ export const initialGalleryState: GalleryState = {
|
||||
galleryImageObjectFit: 'cover',
|
||||
shouldAutoSwitchToNewImages: true,
|
||||
shouldUseSingleGalleryColumn: false,
|
||||
galleryView: 'images',
|
||||
};
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState: initialGalleryState,
|
||||
reducers: {
|
||||
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
|
||||
imageSelected: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.selectedImage = action.payload;
|
||||
// TODO: if the user selects an image, disable the auto switch?
|
||||
// state.shouldAutoSwitchToNewImages = false;
|
||||
@ -48,6 +48,12 @@ export const gallerySlice = createSlice({
|
||||
) => {
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
setGalleryView: (
|
||||
state,
|
||||
action: PayloadAction<'images' | 'assets' | 'boards'>
|
||||
) => {
|
||||
state.galleryView = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(imageUpserted, (state, action) => {
|
||||
@ -55,17 +61,17 @@ export const gallerySlice = createSlice({
|
||||
state.shouldAutoSwitchToNewImages &&
|
||||
action.payload.image_category === 'general'
|
||||
) {
|
||||
state.selectedImage = action.payload;
|
||||
state.selectedImage = action.payload.image_name;
|
||||
}
|
||||
});
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
// builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
// const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
if (state.selectedImage?.image_name === image_name) {
|
||||
state.selectedImage.image_url = image_url;
|
||||
state.selectedImage.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
});
|
||||
// if (state.selectedImage?.image_name === image_name) {
|
||||
// state.selectedImage.image_url = image_url;
|
||||
// state.selectedImage.thumbnail_url = thumbnail_url;
|
||||
// }
|
||||
// });
|
||||
},
|
||||
});
|
||||
|
||||
@ -75,6 +81,7 @@ export const {
|
||||
setGalleryImageObjectFit,
|
||||
setShouldAutoSwitchToNewImages,
|
||||
setShouldUseSingleGalleryColumn,
|
||||
setGalleryView,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
export default gallerySlice.reducer;
|
||||
|
@ -11,7 +11,6 @@ import { dateComparator } from 'common/util/dateComparator';
|
||||
import { keyBy } from 'lodash-es';
|
||||
import {
|
||||
imageDeleted,
|
||||
imageMetadataReceived,
|
||||
imageUrlsReceived,
|
||||
receivedPageOfImages,
|
||||
} from 'services/thunks/image';
|
||||
@ -74,11 +73,21 @@ const imagesSlice = createSlice({
|
||||
});
|
||||
builder.addCase(receivedPageOfImages.fulfilled, (state, action) => {
|
||||
state.isLoading = false;
|
||||
const { boardId, categories, imageOrigin, isIntermediate } =
|
||||
action.meta.arg;
|
||||
|
||||
const { items, offset, limit, total } = action.payload;
|
||||
imagesAdapter.upsertMany(state, items);
|
||||
|
||||
if (!categories?.includes('general') || boardId) {
|
||||
// need to skip updating the total images count if the images recieved were for a specific board
|
||||
// TODO: this doesn't work when on the Asset tab/category...
|
||||
return;
|
||||
}
|
||||
|
||||
state.offset = offset;
|
||||
state.limit = limit;
|
||||
state.total = total;
|
||||
imagesAdapter.upsertMany(state, items);
|
||||
});
|
||||
builder.addCase(imageDeleted.pending, (state, action) => {
|
||||
// Image deleted
|
||||
@ -154,3 +163,16 @@ export const selectFilteredImagesIds = createSelector(
|
||||
.map((i) => i.image_name);
|
||||
}
|
||||
);
|
||||
|
||||
// export const selectImageById = createSelector(
|
||||
// (state: RootState, imageId) => state,
|
||||
// (state) => {
|
||||
// const {
|
||||
// images: { categories },
|
||||
// } = state;
|
||||
|
||||
// return selectImagesAll(state)
|
||||
// .filter((i) => categories.includes(i.image_category))
|
||||
// .map((i) => i.image_name);
|
||||
// }
|
||||
// );
|
||||
|
@ -11,6 +11,8 @@ import { FieldComponentProps } from './types';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { useGetImageDTOQuery } from 'services/apiSlice';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
|
||||
const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
@ -19,9 +21,16 @@ const ImageInputFieldComponent = (
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const {
|
||||
data: image,
|
||||
isLoading,
|
||||
isError,
|
||||
isSuccess,
|
||||
} = useGetImageDTOQuery(field.value ?? skipToken);
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(droppedImage: ImageDTO) => {
|
||||
if (field.value?.image_name === droppedImage.image_name) {
|
||||
if (field.value === droppedImage.image_name) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -29,11 +38,11 @@ const ImageInputFieldComponent = (
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: droppedImage,
|
||||
value: droppedImage.image_name,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, field.value?.image_name, nodeId]
|
||||
[dispatch, field.name, field.value, nodeId]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
@ -56,7 +65,7 @@ const ImageInputFieldComponent = (
|
||||
}}
|
||||
>
|
||||
<IAIDndImage
|
||||
image={field.value}
|
||||
image={image}
|
||||
onDrop={handleDrop}
|
||||
onReset={handleReset}
|
||||
resetIconSize="sm"
|
||||
|
@ -1,28 +1,18 @@
|
||||
import { Select } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ModelInputFieldTemplate,
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { selectModelsIds } from 'features/system/store/modelSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { ChangeEvent, memo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const availableModelsSelector = createSelector(
|
||||
[selectModelsIds],
|
||||
(allModelNames) => {
|
||||
return { allModelNames };
|
||||
// return map(modelList, (_, name) => name);
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useListModelsQuery } from 'services/apiSlice';
|
||||
|
||||
const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
@ -30,28 +20,82 @@ const ModelInputFieldComponent = (
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
||||
const { data: pipelineModels } = useListModelsQuery({
|
||||
model_type: 'pipeline',
|
||||
});
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!pipelineModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(pipelineModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [pipelineModels]);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => pipelineModels?.entities[field.value ?? pipelineModels.ids[0]],
|
||||
[pipelineModels?.entities, pipelineModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.value,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
};
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && pipelineModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModel = pipelineModels?.ids[0];
|
||||
|
||||
if (!isString(firstModel)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstModel);
|
||||
}, [field.value, handleValueChanged, pipelineModels?.ids]);
|
||||
|
||||
return (
|
||||
<Select
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
value={field.value || allModelNames[0]}
|
||||
>
|
||||
{allModelNames.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -101,21 +101,6 @@ const nodesSlice = createSlice({
|
||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||
state.schema = action.payload;
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_url, thumbnail_url } = action.payload;
|
||||
|
||||
state.nodes.forEach((node) => {
|
||||
forEach(node.data.inputs, (input) => {
|
||||
if (input.type === 'image') {
|
||||
if (input.value?.image_name === image_name) {
|
||||
input.value.image_url = image_url;
|
||||
input.value.thumbnail_url = thumbnail_url;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
|
@ -214,7 +214,7 @@ export type VaeInputFieldValue = FieldValueBase & {
|
||||
|
||||
export type ImageInputFieldValue = FieldValueBase & {
|
||||
type: 'image';
|
||||
value?: ImageDTO;
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ModelInputFieldValue = FieldValueBase & {
|
||||
|
@ -2,8 +2,7 @@ import { RootState } from 'app/store/store';
|
||||
import { filter, forEach, size } from 'lodash-es';
|
||||
import { CollectInvocation, ControlNetInvocation } from 'services/api';
|
||||
import { NonNullableGraph } from '../types/types';
|
||||
|
||||
const CONTROL_NET_COLLECT = 'control_net_collect';
|
||||
import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
|
||||
|
||||
export const addControlNetToLinearGraph = (
|
||||
graph: NonNullableGraph,
|
||||
@ -37,7 +36,7 @@ export const addControlNetToLinearGraph = (
|
||||
});
|
||||
}
|
||||
|
||||
forEach(controlNets, (controlNet, index) => {
|
||||
forEach(controlNets, (controlNet) => {
|
||||
const {
|
||||
controlNetId,
|
||||
isEnabled,
|
||||
@ -66,15 +65,13 @@ export const addControlNetToLinearGraph = (
|
||||
|
||||
if (processedControlImage && processorType !== 'none') {
|
||||
// We've already processed the image in the app, so we can just use the processed image
|
||||
const { image_name } = processedControlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_name: processedControlImage,
|
||||
};
|
||||
} else if (controlImage) {
|
||||
// The control image is preprocessed
|
||||
const { image_name } = controlImage;
|
||||
controlNetNode.image = {
|
||||
image_name,
|
||||
image_name: controlImage,
|
||||
};
|
||||
} else {
|
||||
// Skip ControlNets without an unprocessed image - should never happen if everything is working correctly
|
||||
|
@ -1,116 +1,39 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
InpaintInvocation,
|
||||
IterateInvocation,
|
||||
RandomRangeInvocation,
|
||||
RangeInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
|
||||
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
|
||||
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
|
||||
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
|
||||
import { buildEdges } from '../edgeBuilders/buildEdges';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { log } from 'app/logging/useLogger';
|
||||
import { buildInpaintNode } from '../nodeBuilders/buildInpaintNode';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph';
|
||||
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||
|
||||
const moduleLog = log.child({ namespace: 'nodes' });
|
||||
|
||||
const buildBaseNode = (
|
||||
nodeType: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
state: RootState
|
||||
):
|
||||
| TextToImageInvocation
|
||||
| ImageToImageInvocation
|
||||
| InpaintInvocation
|
||||
| undefined => {
|
||||
const overrides = {
|
||||
...state.canvas.boundingBoxDimensions,
|
||||
is_intermediate: true,
|
||||
};
|
||||
|
||||
if (nodeType === 'txt2img') {
|
||||
return buildTxt2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'img2img') {
|
||||
return buildImg2ImgNode(state, overrides);
|
||||
}
|
||||
|
||||
if (nodeType === 'inpaint' || nodeType === 'outpaint') {
|
||||
return buildInpaintNode(state, overrides);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds the Canvas workflow graph and image blobs.
|
||||
*/
|
||||
export const buildCanvasGraphComponents = async (
|
||||
export const buildCanvasGraph = (
|
||||
state: RootState,
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint'
|
||||
): Promise<
|
||||
| {
|
||||
rangeNode: RangeInvocation | RandomRangeInvocation;
|
||||
iterateNode: IterateInvocation;
|
||||
baseNode:
|
||||
| TextToImageInvocation
|
||||
| ImageToImageInvocation
|
||||
| InpaintInvocation;
|
||||
edges: Edge[];
|
||||
}
|
||||
| undefined
|
||||
> => {
|
||||
// The base node is a txt2img, img2img or inpaint node
|
||||
const baseNode = buildBaseNode(generationMode, state);
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
canvasInitImage: ImageDTO | undefined,
|
||||
canvasMaskImage: ImageDTO | undefined
|
||||
) => {
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (!baseNode) {
|
||||
moduleLog.error('Problem building base node');
|
||||
return;
|
||||
if (generationMode === 'txt2img') {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
} else if (generationMode === 'img2img') {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
|
||||
if (baseNode.type === 'inpaint') {
|
||||
const {
|
||||
seamSize,
|
||||
seamBlur,
|
||||
seamSteps,
|
||||
seamStrength,
|
||||
tileSize,
|
||||
infillMethod,
|
||||
} = state.generation;
|
||||
forEach(graph.nodes, (node) => {
|
||||
graph.nodes[node.id].is_intermediate = true;
|
||||
});
|
||||
|
||||
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } =
|
||||
state.canvas;
|
||||
|
||||
if (boundingBoxScaleMethod !== 'none') {
|
||||
baseNode.inpaint_width = scaledBoundingBoxDimensions.width;
|
||||
baseNode.inpaint_height = scaledBoundingBoxDimensions.height;
|
||||
}
|
||||
|
||||
baseNode.seam_size = seamSize;
|
||||
baseNode.seam_blur = seamBlur;
|
||||
baseNode.seam_strength = seamStrength;
|
||||
baseNode.seam_steps = seamSteps;
|
||||
baseNode.infill_method = infillMethod as InpaintInvocation['infill_method'];
|
||||
|
||||
if (infillMethod === 'tile') {
|
||||
baseNode.tile_size = tileSize;
|
||||
}
|
||||
}
|
||||
|
||||
// We always range and iterate nodes, no matter the iteration count
|
||||
// This is required to provide the correct seeds to the backend engine
|
||||
const rangeNode = buildRangeNode(state);
|
||||
const iterateNode = buildIterateNode();
|
||||
|
||||
// Build the edges for the nodes selected.
|
||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
||||
|
||||
return {
|
||||
rangeNode,
|
||||
iterateNode,
|
||||
baseNode,
|
||||
edges,
|
||||
};
|
||||
return graph;
|
||||
};
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user