Merge branch 'main' into bugfix/fp16-models

This commit is contained in:
StAlKeR7779 2023-08-05 01:42:43 +03:00 committed by GitHub
commit 9bacd77a79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
148 changed files with 3697 additions and 2813 deletions

View File

@ -1,13 +1,14 @@
name: Black # TODO: add isort and flake8 later
name: style checks
# just formatting for now
# TODO: add isort and flake8 later
on:
pull_request: {}
pull_request:
push:
branches: master
tags: "*"
branches: main
jobs:
test:
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@ -19,8 +20,7 @@ jobs:
- name: Install dependencies with pip
run: |
pip install --upgrade pip wheel
pip install .[test]
pip install black
# - run: isort --check-only .
- run: black --check .

View File

@ -1,50 +0,0 @@
name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on:
pull_request:
paths:
- '**'
- '!pyproject.toml'
- '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**'
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
- pytorch: linux-cpu
os: ubuntu-22.04
- pytorch: macos-default
os: macOS-12
- pytorch: windows-cpu
os: windows-2022
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
steps:
- name: skip
run: echo "no build required"

View File

@ -3,16 +3,7 @@ on:
push:
branches:
- 'main'
paths:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
pull_request:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**'
types:
- 'ready_for_review'
- 'opened'
@ -65,10 +56,23 @@ jobs:
id: checkout-sources
uses: actions/checkout@v3
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v37
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: set test prompt to main branch validation
if: steps.changed-files.outputs.python_any_changed == 'true'
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
if: steps.changed-files.outputs.python_any_changed == 'true'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
@ -76,6 +80,7 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install invokeai
if: steps.changed-files.outputs.python_any_changed == 'true'
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
@ -83,6 +88,7 @@ jobs:
--editable=".[test]"
- name: run pytest
if: steps.changed-files.outputs.python_any_changed == 'true'
id: run-pytest
run: pytest

View File

@ -2,7 +2,6 @@
from typing import Optional
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
@ -30,6 +29,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService
@ -128,6 +128,7 @@ class ApiDependencies:
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
)

View File

@ -1,24 +1,30 @@
from fastapi import Body, HTTPException, Path, Query
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from pydantic import BaseModel, Field
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="create_board_image",
operation_id="add_image_to_board",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def create_board_image(
async def add_image_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
@ -29,26 +35,78 @@ async def create_board_image(
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add to board")
raise HTTPException(status_code=500, detail="Failed to add image to board")
@board_images_router.delete(
"/",
operation_id="remove_board_image",
operation_id="remove_image_from_board",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
):
"""Deletes a board_image"""
"""Removes an image from its board, if it had one"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
board_id=board_id, image_name=image_name
)
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@board_images_router.post(
"/batch",
operation_id="add_images_to_board",
responses={
201: {"description": "Images were added to board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_images_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_names: list[str] = Body(description="The names of the images to add", embed=True),
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
)
added_image_names.append(image_name)
except:
pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@board_images_router.post(
"/batch/delete",
operation_id="remove_images_from_board",
responses={
201: {"description": "Images were removed from board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_images_from_board(
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name)
except:
pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to remove images from board")

View File

@ -5,6 +5,7 @@ from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadF
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, Field
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
@ -25,7 +26,7 @@ IMAGE_MAX_AGE = 31536000
@images_router.post(
"/",
"/upload",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
@ -77,7 +78,7 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_name}", operation_id="delete_image")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> None:
@ -103,7 +104,7 @@ async def clear_intermediates() -> int:
@images_router.patch(
"/{image_name}",
"/i/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
@ -120,7 +121,7 @@ async def update_image(
@images_router.get(
"/{image_name}",
"/i/{image_name}",
operation_id="get_image_dto",
response_model=ImageDTO,
)
@ -136,7 +137,7 @@ async def get_image_dto(
@images_router.get(
"/{image_name}/metadata",
"/i/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageMetadata,
)
@ -152,7 +153,7 @@ async def get_image_metadata(
@images_router.get(
"/{image_name}/full",
"/i/{image_name}/full",
operation_id="get_image_full",
response_class=Response,
responses={
@ -187,7 +188,7 @@ async def get_image_full(
@images_router.get(
"/{image_name}/thumbnail",
"/i/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@ -216,7 +217,7 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_name}/urls",
"/i/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
@ -265,3 +266,24 @@ async def list_image_dtos(
)
return image_dtos
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesFromListResult:
try:
deleted_images: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
except:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to delete images")

View File

@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.invocation_stats import InvocationStatsService
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -311,6 +312,7 @@ def invoke_cli():
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
)

View File

@ -109,12 +109,15 @@ class CompelInvocation(BaseInvocation):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
).context.model,
)
)
except ModelNotFoundException:
# print(e)
@ -173,7 +176,7 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@ -197,12 +200,15 @@ class SDXLPromptInvocationBase:
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
).context.model,
)
)
except ModelNotFoundException:
# print(e)
@ -210,8 +216,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@ -247,7 +253,7 @@ class SDXLPromptInvocationBase:
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@ -271,12 +277,15 @@ class SDXLPromptInvocationBase:
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
).context.model,
)
)
except ModelNotFoundException:
# print(e)
@ -284,8 +293,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@ -357,11 +366,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -415,7 +424,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -467,11 +477,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
else:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -525,7 +535,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)

View File

@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,

View File

@ -1,6 +1,6 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -10,16 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class LoRAMetadataField(BaseModel):
class LoRAMetadataField(BaseModelExcludeNull):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
class CoreMetadata(BaseModel):
class CoreMetadata(BaseModelExcludeNull):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(
@ -70,7 +71,7 @@ class CoreMetadata(BaseModel):
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModel):
class ImageMetadata(BaseModelExcludeNull):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(

View File

@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
# fmt: on
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output
class VAEModelField(BaseModel):
"""Vae model field"""

View File

@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras
@ -76,18 +75,14 @@ class ONNXPromptInvocation(BaseInvocation):
name = trigger[1:-1]
try:
ti_list.append(
# stack.enter_context(
# context.services.model_manager.get_model(
# model_name=name,
# base_model=self.clip.text_encoder.base_model,
# model_type=ModelType.TextualInversion,
# )
# )
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
).context.model,
)
)
except Exception:
# print(e)

View File

@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@ -293,10 +293,20 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
num_inference_steps = self.steps
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
@ -543,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
context=context,
)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet.device)

View File

@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE board_id = ? AND image_name = ?;
WHERE image_name = ?;
""",
(board_id, image_name),
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:

View File

@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC):
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
@ -93,10 +92,9 @@ class BoardImagesService(BoardImagesServiceABC):
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(board_id, image_name)
self._services.board_image_records.remove_image_from_board(image_name)
def get_all_board_image_names_for_board(
self,

View File

@ -289,9 +289,10 @@ class ImageService(ImageServiceABC):
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
metadata = self._services.image_records.get_metadata(image_name)
if not image_record.session_id:
return ImageMetadata()
return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
graph = None
@ -303,7 +304,6 @@ class ImageService(ImageServiceABC):
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
metadata = self._services.image_records.get_metadata(image_name)
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")

View File

@ -32,6 +32,7 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
def __init__(
@ -47,6 +48,7 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
):
self.board_images = board_images
@ -61,4 +63,5 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@ -0,0 +1,223 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
"""
Usage:
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
Typical output:
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict
import torch
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
"""
pass
@abstractmethod
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> AbstractContextManager:
"""
Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str):
"""
Reset all statistics for the indicated graph
:param graph_execution_state_id
"""
pass
@abstractmethod
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
"""
pass
@abstractmethod
def log_stats(self):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
class StatsContext:
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def __exit__(self, *args):
self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's exection
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed if when the execution of the graph
is complete.
"""
completed = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
if not current_graph_state.is_complete():
continue
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]

View File

@ -0,0 +1,8 @@
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardImage(BaseModelExcludeNull):
board_id: str = Field(description="The id of the board")
image_name: str = Field(description="The name of the image")

View File

@ -1,10 +1,11 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from pydantic import Field
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardRecord(BaseModel):
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")

View File

@ -1,13 +1,14 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from pydantic import Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class ImageRecord(BaseModel):
class ImageRecord(BaseModelExcludeNull):
"""Deserialized image record without metadata."""
image_name: str = Field(description="The unique name of the image.")
@ -40,7 +41,7 @@ class ImageRecord(BaseModel):
"""The node ID that generated this image, if it is a generated image."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
@ -60,7 +61,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel):
class ImageUrlsDTO(BaseModelExcludeNull):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
@ -76,11 +77,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
image_record: ImageRecord,
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(

View File

@ -1,14 +1,15 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
from threading import BoundedSemaphore, Event, Thread
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@ -83,6 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,
@ -107,11 +111,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
source_node_id=source_node_id,
result=outputs.dict(),
)
statistics.log_stats()
except KeyboardInterrupt:
pass
except CanceledException:
statistics.reset_stats(graph_execution_state.id)
pass
except Exception as e:
@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__,
error=error,
)
statistics.reset_stats(graph_execution_state.id)
pass
# Check queue to see if this is canceled, and skip if so

View File

@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}/full"
return f"{self._base_url}/images/i/{image_basename}/full"

View File

@ -0,0 +1,23 @@
from typing import Any
from pydantic import BaseModel
"""
We want to exclude null values from objects that make their way to the client.
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
dict method to do this.
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
"""
class BaseModelExcludeNull(BaseModel):
def dict(self, *args, **kwargs) -> dict[str, Any]:
"""
Override the default dict method to exclude None values in the response
"""
kwargs.pop("exclude_none", None)
return super().dict(*args, exclude_none=True, **kwargs)
pass

View File

@ -305,7 +305,7 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if "model_index.json" in files and "unet/model.onnx" not in files:
if "model_index.json" in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)

View File

@ -13,3 +13,4 @@ from .models import (
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher

View File

@ -20,424 +20,6 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
# TODO: rename and split this file
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
"""
loras = [
(lora_model1, 0.7),
@ -516,6 +98,26 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod
@contextmanager
def apply_lora(
@ -562,7 +164,7 @@ class ModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Any],
ti_list: List[Tuple[str, Any]],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None
new_tokens_added = None
@ -572,27 +174,27 @@ class ModelPatcher:
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti, index):
trigger = ti.name
def _get_trigger(ti_name, index):
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings()
for ti in ti_list:
for ti_name, ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i]
trigger = _get_trigger(ti, i)
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
@ -637,7 +239,6 @@ class ModelPatcher:
class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod
@ -651,7 +252,6 @@ class TextualInversionModel:
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@ -828,7 +428,7 @@ class ONNXModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Any],
ti_list: List[Tuple[str, Any]],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel
@ -841,17 +441,17 @@ class ONNXModelPatcher:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti, index):
trigger = ti.name
def _get_trigger(ti_name, index):
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for ti_name, ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@ -861,11 +461,11 @@ class ONNXModelPatcher:
axis=0,
)
for ti in ti_list:
for ti_name, ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti, i)
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:

View File

@ -28,8 +28,6 @@ import torch
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs
@ -188,7 +186,7 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
# this will remove older cached models until

View File

@ -472,7 +472,7 @@ class ModelManager(object):
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.app_config.root_path / override_path
model_path = self.resolve_path(override_path)
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
@ -670,7 +670,7 @@ class ModelManager(object):
# TODO: if path changed and old_model.path inside models folder should we delete this too?
# remove conversion cache as config changed
old_model_path = self.app_config.root_path / old_model.path
old_model_path = self.resolve_model_path(old_model.path)
old_model_cache = self._get_model_cache_path(old_model_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
@ -780,7 +780,7 @@ class ModelManager(object):
model_type,
**submodel,
)
checkpoint_path = self.app_config.root_path / info["path"]
checkpoint_path = self.resolve_model_path(info["path"])
old_diffusers_path = self.resolve_model_path(model.location)
new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value
@ -992,7 +992,7 @@ class ModelManager(object):
model_manager=self,
prediction_type_helper=ask_user_for_prediction_type,
)
known_paths = {config.root_path / x["path"] for x in self.list_models()}
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
directories = {
config.root_path / x
for x in [

View File

@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
# misclassified as SD-1
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
if key in checkpoint and checkpoint[key].shape[0] == 320:
return BaseModelType.StableDiffusion2
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
if key in checkpoint:
return BaseModelType.StableDiffusionXL
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
else checkpoint[key2].shape[1]
if key2 in checkpoint
else 768
else checkpoint[key3].shape[0]
if key3 in checkpoint
else None
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
return None
raise InvalidModelException(f"Unknown LoRA type")
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@ -292,8 +292,9 @@ class DiffusersModel(ModelBase):
)
break
except Exception as e:
# print("====ERR LOAD====")
# print(f"{variant}: {e}")
if not str(e).startswith("Error no file"):
print("====ERR LOAD====")
print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")

View File

@ -1,7 +1,9 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from typing import Optional, Dict, Union, Literal, Any
from pathlib import Path
from safetensors.torch import load_file
from .base import (
ModelBase,
ModelConfigBase,
@ -13,9 +15,6 @@ from .base import (
ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
base_model=self.base_model,
)
self.model_size = model.calc_size()
@ -87,3 +87,582 @@ class LoRAModel(ModelBase):
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self):
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_compvis_keys(cls, state_dict):
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
src_key = full_key.replace("lora_unet_", "")
try:
dst_key = None
while "_" in src_key:
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split("_")[:-1])
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
new_key = full_key.replace(src_key, dst_key)
except:
print(SDXL_UNET_COMPVIS_MAP)
raise
new_state_dict[new_key] = value
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
elif "diff" in values:
layer = FullLayer(layer_key, values)
else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@ -4,6 +4,7 @@ from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from .base import (
ModelConfigBase,
BaseModelType,
@ -263,6 +264,8 @@ def _convert_ckpt_and_cache(
weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
variant = model_config.variant
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
# return cached version if it exists
if output_path.exists():
@ -289,6 +292,7 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()),
**kwargs,

View File

@ -78,10 +78,9 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
cls,
self,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
@ -91,18 +90,19 @@ class InvokeAIDiffuserComponent:
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
self.cross_attention_control_context,
)
try:
yield None
finally:
self.cross_attention_control_context = None
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving

View File

@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"typegen": "npx ts-node scripts/typegen.ts",
"typegen": "node scripts/typegen.js",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

View File

@ -124,7 +124,8 @@
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets"
"assets": "Assets",
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",

View File

@ -4,8 +4,9 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import SiteHeader from 'features/system/components/SiteHeader';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
@ -16,7 +17,6 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useEffect } from 'react';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
</Portal>
</Grid>
<DeleteImageModal />
<UpdateImageBoardModal />
<ChangeBoardModal />
<Toaster />
<GlobalHotkeys />
</>

View File

@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
);
}
if (props.dragData.payloadType === 'IMAGE_NAMES') {
if (props.dragData.payloadType === 'IMAGE_DTOS') {
return (
<Flex
sx={{
@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES,
}}
>
<Heading>{props.dragData.payload.image_names.length}</Heading>
<Heading>{props.dragData.payload.imageDTOs.length}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);

View File

@ -18,27 +18,32 @@ import {
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => {
const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const log = logger('images');
const dispatch = useAppDispatch();
const handleDragStart = useCallback((event: DragStartEvent) => {
console.log('dragStart', event.active.data.current);
const handleDragStart = useCallback(
(event: DragStartEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag started');
const activeData = event.active.data.current;
if (!activeData) {
return;
}
setActiveDragData(activeData);
}, []);
},
[log]
);
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
console.log('dragEnd', event.active.data.current);
log.trace({ dragData: event.active.data.current }, 'Drag ended');
const overData = event.over?.data.current;
if (!activeDragData || !overData) {
return;
@ -46,7 +51,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
dispatch(dndDropped({ overData, activeData: activeDragData }));
setActiveDragData(null);
},
[activeDragData, dispatch]
[activeDragData, dispatch, log]
);
const mouseSensor = useSensor(MouseSensor, {

View File

@ -11,7 +11,6 @@ import {
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import { BoardId } from 'features/gallery/store/types';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@ -54,9 +53,13 @@ export type AddToBatchDropData = BaseDropData & {
actionType: 'ADD_TO_BATCH';
};
export type MoveBoardDropData = BaseDropData & {
actionType: 'MOVE_BOARD';
context: { boardId: BoardId };
export type AddToBoardDropData = BaseDropData & {
actionType: 'ADD_TO_BOARD';
context: { boardId: string };
};
export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type TypesafeDroppableData =
@ -67,7 +70,8 @@ export type TypesafeDroppableData =
| NodesImageDropData
| AddToBatchDropData
| NodesMultiImageDropData
| MoveBoardDropData;
| AddToBoardDropData
| RemoveFromBoardDropData;
type BaseDragData = {
id: string;
@ -78,14 +82,12 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO };
};
export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'IMAGE_NAMES';
payload: { image_names: string[] };
export type ImageDTOsDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTOS';
payload: { imageDTOs: ImageDTO[] };
};
export type TypesafeDraggableData =
| ImageDraggableData
| ImageNamesDraggableData;
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
@ -156,14 +158,39 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'MOVE_BOARD': {
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
@ -172,20 +199,16 @@ export const isValidDrop = (
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
const destinationBoard = overData.context.boardId;
const isSameBoard = currentBoard === destinationBoard;
const isDestinationValid = !currentBoard ? destinationBoard : true;
return !isSameBoard && isDestinationValid;
return currentBoard !== 'none';
}
if (payloadType === 'IMAGE_NAMES') {
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return false;
return true;
}
return true;
return false;
}
default:
return false;

View File

@ -1,4 +1,6 @@
import { Middleware } from '@reduxjs/toolkit';
import { store } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai';
import React, {
lazy,
memo,
@ -7,16 +9,11 @@ import React, {
useEffect,
} from 'react';
import { Provider } from 'react-redux';
import { PartialAppConfig } from 'app/types/invokeai';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import Loading from '../../common/components/Loading/Loading';
import { Middleware } from '@reduxjs/toolkit';
import { $authToken, $baseUrl } from 'services/api/client';
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import '../../i18n';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import ImageDndContext from './ImageDnd/ImageDndContext';
const App = lazy(() => import('./App'));
@ -37,6 +34,7 @@ const InvokeAIUI = ({
config,
headerComponent,
middleware,
projectId,
}: Props) => {
useEffect(() => {
// configure API client token
@ -49,6 +47,11 @@ const InvokeAIUI = ({
$baseUrl.set(apiUrl);
}
// configure API client project header
if (projectId) {
$projectId.set(projectId);
}
// reset dynamically added middlewares
resetMiddlewares();
@ -68,8 +71,9 @@ const InvokeAIUI = ({
// Reset the API client token and base url on unmount
$baseUrl.set(undefined);
$authToken.set(undefined);
$projectId.set(undefined);
};
}, [apiUrl, token, middleware]);
}, [apiUrl, token, middleware, projectId]);
return (
<React.StrictMode>
@ -77,9 +81,7 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<ImageDndContext>
<AddImageToBoardContextProvider>
<App config={config} headerComponent={headerComponent} />
</AddImageToBoardContextProvider>
</ImageDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@ -1,91 +0,0 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import { imagesApi } from 'services/api/endpoints/images';
import { useAppDispatch } from '../store/storeHooks';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
type AddImageToBoardContextValue = {
/**
* Whether the move image dialog is open.
*/
isOpen: boolean;
/**
* Closes the move image dialog.
*/
onClose: () => void;
/**
* The image pending movement
*/
image?: ImageDTO;
onClickAddToBoard: (image: ImageDTO) => void;
handleAddToBoard: (boardId: string) => void;
};
export const AddImageToBoardContext =
createContext<AddImageToBoardContextValue>({
isOpen: false,
onClose: () => undefined,
onClickAddToBoard: () => undefined,
handleAddToBoard: () => undefined,
});
type Props = PropsWithChildren;
export const AddImageToBoardContextProvider = (props: Props) => {
const [imageToMove, setImageToMove] = useState<ImageDTO>();
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToMove(undefined);
onClose();
}, [onClose]);
const onClickAddToBoard = useCallback(
(image?: ImageDTO) => {
if (!image) {
return;
}
setImageToMove(image);
onOpen();
},
[setImageToMove, onOpen]
);
const handleAddToBoard = useCallback(
(boardId: string) => {
if (imageToMove) {
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO: imageToMove,
board_id: boardId,
})
);
closeAndClearImageToDelete();
}
},
[dispatch, closeAndClearImageToDelete, imageToMove]
);
return (
<AddImageToBoardContext.Provider
value={{
isOpen,
image: imageToMove,
onClose: closeAndClearImageToDelete,
onClickAddToBoard,
handleAddToBoard,
}}
>
{props.children}
</AddImageToBoardContext.Provider>
);
};

View File

@ -1,8 +0,0 @@
import { createContext } from 'react';
type VoidFunc = () => void;
type ImageUploaderTriggerContextType = VoidFunc | null;
export const ImageUploaderTriggerContext =
createContext<ImageUploaderTriggerContextType>(null);

View File

@ -23,6 +23,6 @@ const serializationDenylist: {
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key]);
const result = omit(data, serializationDenylist[key] ?? []);
return JSON.stringify(result);
};

View File

@ -27,7 +27,8 @@ import {
addImageDeletedFulfilledListener,
addImageDeletedPendingListener,
addImageDeletedRejectedListener,
addRequestedImageDeletionListener,
addRequestedSingleImageDeletionListener,
addRequestedMultipleImageDeletionListener,
} from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped';
import {
@ -111,7 +112,8 @@ addImageUploadedRejectedListener();
addInitialImageSelectedListener();
// Image deleted
addRequestedImageDeletionListener();
addRequestedSingleImageDeletionListener();
addRequestedMultipleImageDeletionListener();
addImageDeletedPendingListener();
addImageDeletedFulfilledListener();
addImageDeletedRejectedListener();

View File

@ -1,12 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import {
ImageCache,
getListImagesUrl,
imagesApi,
} from 'services/api/endpoints/images';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
import { ImageCache } from 'services/api/types';
export const appStarted = createAction('app/appStarted');
@ -34,7 +32,8 @@ export const addFirstListImagesListener = () => {
if (data.ids.length > 0) {
// Select the first image
dispatch(imageSelected(data.ids[0] as string));
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
}
},
});

View File

@ -18,7 +18,9 @@ export const addAppConfigReceivedListener = () => {
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
dispatch(setInfillMethod(infill_methods[0]));
// if there is no infill method, set it to the first one
// if there is no first one... god help us
dispatch(setInfillMethod(infill_methods[0] as string));
}
if (!nsfw_methods.includes('nsfw_checker')) {

View File

@ -1,14 +1,14 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { boardsApi } from '../../../../../services/api/endpoints/boards';
export const addDeleteBoardAndImagesFulfilledListener = () => {
startAppListening({
matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
const { deleted_images } = action.payload;

View File

@ -10,6 +10,7 @@ import {
} from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = () => {
startAppListening({
@ -52,8 +53,9 @@ export const addBoardIdSelectedListener = () => {
queryArgs
)(getState());
if (boardImagesData?.ids.length) {
dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
if (boardImagesData) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage ?? null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

View File

@ -26,6 +26,8 @@ export const addCanvasSavedToGalleryListener = () => {
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
@ -33,7 +35,7 @@ export const addCanvasSavedToGalleryListener = () => {
}),
image_category: 'general',
is_intermediate: false,
board_id: state.gallery.autoAddBoardId,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',

View File

@ -31,15 +31,20 @@ const predicate: AnyListenerPredicate<RootState> = (
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
.shouldAutoConfig === true
?.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];
const cn = state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;

View File

@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
if (!controlNet.controlImage) {
if (!controlNet?.controlImage) {
log.error('Unable to process ControlNet image');
return;
}

View File

@ -1,57 +1,72 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
/**
* Called when the user requests an image deletion
*/
export const addRequestedImageDeletionListener = () => {
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState, condition }) => {
const { imageDTO, imageUsage } = action.payload;
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
// handle multiples in separate listener
return;
}
const imageDTO = imageDTOs[0];
const imageUsage = imagesUsage[0];
if (!imageDTO || !imageUsage) {
// satisfy noUncheckedIndexedAccess
return;
}
dispatch(isModalOpenChanged(false));
const { image_name } = imageDTO;
const state = getState();
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1];
state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
if (lastSelectedImage === image_name) {
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const ids = data?.ids ?? [];
const cachedImageDTOs = data
? imagesAdapter.getSelectors().selectAll(data)
: [];
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
const deletedImageIndex = cachedImageDTOs.findIndex(
(i) => i.image_name === image_name
);
const filteredIds = ids.filter((id) => id.toString() !== image_name);
const filteredImageDTOs = cachedImageDTOs.filter(
(i) => i.image_name !== image_name
);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
filteredImageDTOs.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageId as string));
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
@ -97,6 +112,66 @@ export const addRequestedImageDeletionListener = () => {
});
};
/**
* Called when the user requests an image deletion
*/
export const addRequestedMultipleImageDeletionListener = () => {
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
// handle singles in separate listener
return;
}
try {
// Delete from server
await dispatch(
imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
).unwrap();
const state = getState();
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const newSelectedImageDTO = data
? imagesAdapter.getSelectors().selectAll(data)[0]
: undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
dispatch(isModalOpenChanged(false));
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imagesUsage.some((i) => i.isCanvasImage)) {
dispatch(resetCanvas());
}
if (imagesUsage.some((i) => i.isControlNetImage)) {
dispatch(controlNetReset());
}
if (imagesUsage.some((i) => i.isInitialImage)) {
dispatch(clearInitialImage());
}
if (imagesUsage.some((i) => i.isNodesImage)) {
dispatch(nodeEditorReset());
}
} catch {
// no-op
}
},
});
};
/**
* Called when the actual delete request is sent to the server
*/

View File

@ -6,10 +6,7 @@ import {
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
imageSelected,
imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -27,19 +24,32 @@ export const addImageDroppedListener = () => {
const log = logger('images');
const { activeData, overData } = action.payload;
log.debug({ activeData, overData }, 'Image or selection dropped');
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'IMAGE_DTOS') {
log.debug(
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
// set current image
/**
* Image dropped on current image
*/
if (
overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
dispatch(imageSelected(activeData.payload.imageDTO));
return;
}
// set initial image
/**
* Image dropped on initial image
*/
if (
overData.actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@ -49,27 +59,9 @@ export const addImageDroppedListener = () => {
return;
}
// add image to batch
if (
overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
return;
}
// add multiple images to batch
if (
overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES'
) {
dispatch(imagesAddedToBatch(activeData.payload.image_names));
return;
}
// set control image
/**
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@ -85,7 +77,9 @@ export const addImageDroppedListener = () => {
return;
}
// set canvas image
/**
* Image dropped on Canvas
*/
if (
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@ -95,7 +89,9 @@ export const addImageDroppedListener = () => {
return;
}
// set nodes image
/**
* Image dropped on node image field
*/
if (
overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@ -112,61 +108,36 @@ export const addImageDroppedListener = () => {
return;
}
// set multiple nodes images (single image handler)
if (
overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: [activeData.payload.imageDTO],
})
);
return;
}
// // set multiple nodes images (multiple images handler)
/**
* TODO
* Image selection dropped on node image collection field
*/
// if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_NAMES'
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO
// ) {
// const { fieldName, nodeId } = overData.context;
// dispatch(
// imageCollectionFieldValueChanged({
// fieldValueChanged({
// nodeId,
// fieldName,
// value: activeData.payload.image_names.map((image_name) => ({
// image_name,
// })),
// value: [activeData.payload.imageDTO],
// })
// );
// return;
// }
// add image to board
/**
* Image dropped on user board
*/
if (
overData.actionType === 'MOVE_BOARD' &&
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
const { boardId } = overData.context;
// image was droppe on the "NoBoardBoard"
if (!boardId) {
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
return;
}
// image was dropped on a user board
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO,
@ -176,67 +147,58 @@ export const addImageDroppedListener = () => {
return;
}
// // add gallery selection to board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId
// ) {
// console.log('adding gallery selection to board');
// const board_id = overData.context.boardId;
// dispatch(
// boardImagesApi.endpoints.addManyBoardImages.initiate({
// board_id,
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
/**
* Image dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
return;
}
// // remove gallery selection from board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId === null
// ) {
// console.log('removing gallery selection to board');
// dispatch(
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
/**
* Multiple images dropped on user board
*/
if (
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTOS' &&
activeData.payload.imageDTOs
) {
const { imageDTOs } = activeData.payload;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImagesToBoard.initiate({
imageDTOs,
board_id: boardId,
})
);
return;
}
// // add batch selection to board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId
// ) {
// const board_id = overData.context.boardId;
// dispatch(
// boardImagesApi.endpoints.addManyBoardImages.initiate({
// board_id,
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
// // remove batch selection from board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId === null
// ) {
// dispatch(
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
/**
* Multiple images dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTOS' &&
activeData.payload.imageDTOs
) {
const { imageDTOs } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImagesFromBoard.initiate({
imageDTOs,
})
);
return;
}
},
});
};

View File

@ -1,37 +1,32 @@
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
import {
imageToDeleteSelected,
imagesToDeleteSelected,
isModalOpenChanged,
} from 'features/imageDeletion/store/imageDeletionSlice';
} from 'features/deleteImageModal/store/slice';
import { startAppListening } from '..';
export const addImageToDeleteSelectedListener = () => {
startAppListening({
actionCreator: imageToDeleteSelected,
actionCreator: imagesToDeleteSelected,
effect: async (action, { dispatch, getState }) => {
const imageDTO = action.payload;
const imageDTOs = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imageUsage = selectImageUsage(getState());
if (!imageUsage) {
// should never happen
return;
}
const imagesUsage = selectImageUsage(getState());
const isImageInUse =
imageUsage.isCanvasImage ||
imageUsage.isInitialImage ||
imageUsage.isControlNetImage ||
imageUsage.isNodesImage;
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
},
});
};

View File

@ -2,14 +2,13 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images';
import { omit } from 'lodash-es';
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
title: 'Image Uploaded',
@ -41,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => {
// default action - just upload and alert user
if (postUploadAction?.type === 'TOAST') {
const { toastOptions } = postUploadAction;
if (!autoAddBoardId) {
if (!autoAddBoardId || autoAddBoardId === 'none') {
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
} else {
// Add this image to the board
@ -121,17 +120,6 @@ export const addImageUploadedFulfilledListener = () => {
);
return;
}
if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imagesAddedToBatch([imageDTO.image_name]));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: 'Added to batch',
})
);
return;
}
},
});
};

View File

@ -15,7 +15,7 @@ import {
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
@ -144,8 +144,9 @@ export const addModelsLoadedListener = () => {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
const firstModel = vaeModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
// No custom VAEs loaded at all; use the default

View File

@ -8,9 +8,10 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { sessionCanceled } from 'services/api/thunks/session';
import { imagesAdapter } from 'services/api/util';
import {
appSocketInvocationComplete,
socketInvocationComplete,
@ -67,7 +68,7 @@ export const addInvocationCompleteEventListener = () => {
*/
const { autoAddBoardId } = gallery;
if (autoAddBoardId) {
if (autoAddBoardId && autoAddBoardId !== 'none') {
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
board_id: autoAddBoardId,
@ -83,10 +84,7 @@ export const addInvocationCompleteEventListener = () => {
categories: IMAGE_CATEGORIES,
},
(draft) => {
const oldTotal = draft.total;
const newState = imagesAdapter.addOne(draft, imageDTO);
const delta = newState.total - oldTotal;
draft.total = draft.total + delta;
imagesAdapter.addOne(draft, imageDTO);
}
)
);
@ -94,8 +92,8 @@ export const addInvocationCompleteEventListener = () => {
dispatch(
imagesApi.util.invalidateTags([
{ type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
{ type: 'BoardImagesTotal', id: autoAddBoardId },
{ type: 'BoardAssetsTotal', id: autoAddBoardId },
])
);
@ -110,7 +108,7 @@ export const addInvocationCompleteEventListener = () => {
} else if (!autoAddBoardId) {
dispatch(galleryViewChanged('images'));
}
dispatch(imageSelected(imageDTO.image_name));
dispatch(imageSelected(imageDTO));
}
}

View File

@ -8,9 +8,9 @@ import {
import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';
@ -43,9 +43,9 @@ const allReducers = {
ui: uiReducer,
hotkeys: hotkeysReducer,
controlNet: controlNetReducer,
boards: boardsReducer,
dynamicPrompts: dynamicPromptsReducer,
imageDeletion: imageDeletionReducer,
deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer,
lora: loraReducer,
modelmanager: modelmanagerReducer,
sdxl: sdxlReducer,

View File

@ -1,4 +1,4 @@
import { Flex, Text, useColorMode } from '@chakra-ui/react';
import { Box, Flex, useColorMode } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { mode } from 'theme/util/mode';
@ -74,7 +74,7 @@ export const IAIDropOverlay = (props: Props) => {
justifyContent: 'center',
}}
>
<Text
<Box
sx={{
fontSize: '2xl',
fontWeight: 600,
@ -87,7 +87,7 @@ export const IAIDropOverlay = (props: Props) => {
}}
>
{label}
</Text>
</Box>
</Flex>
</Flex>
</motion.div>

View File

@ -53,7 +53,9 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
// cannot figure out why we were doing this, but it was causing an issue where if you
// select the currently-selected item, it reset the search value to empty
// setSearchValue('');
if (!onChange) {
return;

View File

@ -78,7 +78,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
image_category: 'user',
is_intermediate: false,
postUploadAction,
board_id: autoAddBoardId,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
},
[autoAddBoardId, postUploadAction, uploadImage]

View File

@ -49,7 +49,7 @@ export const useImageUploadButton = ({
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
});
},
[autoAddBoardId, postUploadAction, uploadImage]

View File

@ -33,6 +33,10 @@ const useColorPicker = () => {
1
).data;
if (!(a && r && g && b)) {
return;
}
dispatch(setColorPickerColor({ r, g, b, a }));
},
commitColorUnderCursor: () => {

View File

@ -727,10 +727,13 @@ export const canvasSlice = createSlice({
state.pastLayerStates.shift();
}
state.layerState.objects.push({
...images[selectedImageIndex],
});
const imageToCommit = images[selectedImageIndex];
if (imageToCommit) {
state.layerState.objects.push({
...imageToCommit,
});
}
state.layerState.stagingArea = {
...initialLayerState.stagingArea,
};

View File

@ -0,0 +1,132 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Flex,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import {
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,
} from 'services/api/endpoints/images';
import { changeBoardReset, isModalOpenChanged } from '../store/slice';
const selector = createSelector(
[stateSelector],
({ changeBoardModal }) => {
const { isModalOpen, imagesToChange } = changeBoardModal;
return {
isModalOpen,
imagesToChange,
};
},
defaultSelectorOptions
);
const ChangeBoardModal = () => {
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery();
const { imagesToChange, isModalOpen } = useAppSelector(selector);
const [addImagesToBoard] = useAddImagesToBoardMutation();
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
const data = useMemo(() => {
const data: { label: string; value: string }[] = [
{ label: 'Uncategorized', value: 'none' },
];
(boards ?? []).forEach((board) =>
data.push({
label: board.board_name,
value: board.board_id,
})
);
return data;
}, [boards]);
const handleClose = useCallback(() => {
dispatch(changeBoardReset());
dispatch(isModalOpenChanged(false));
}, [dispatch]);
const handleChangeBoard = useCallback(() => {
if (!imagesToChange.length || !selectedBoard) {
return;
}
if (selectedBoard === 'none') {
removeImagesFromBoard({ imageDTOs: imagesToChange });
} else {
addImagesToBoard({
imageDTOs: imagesToChange,
board_id: selectedBoard,
});
}
setSelectedBoard(null);
dispatch(changeBoardReset());
}, [
addImagesToBoard,
dispatch,
imagesToChange,
removeImagesFromBoard,
selectedBoard,
]);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
<AlertDialog
isOpen={isModalOpen}
onClose={handleClose}
leastDestructiveRef={cancelRef}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Change Board
</AlertDialogHeader>
<AlertDialogBody>
<Flex sx={{ flexDir: 'column', gap: 4 }}>
<Text>
Moving {`${imagesToChange.length}`} image
{`${imagesToChange.length > 1 ? 's' : ''}`} to board:
</Text>
<IAIMantineSearchableSelect
placeholder={isFetching ? 'Loading...' : 'Select Board'}
disabled={isFetching}
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={data}
/>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton ref={cancelRef} onClick={handleClose}>
Cancel
</IAIButton>
<IAIButton colorScheme="accent" onClick={handleChangeBoard} ml={3}>
Move
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(ChangeBoardModal);

View File

@ -0,0 +1,6 @@
import { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
imagesToChange: [],
};

View File

@ -0,0 +1,25 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { initialState } from './initialState';
const changeBoardModal = createSlice({
name: 'changeBoardModal',
initialState,
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
},
imagesToChangeSelected: (state, action: PayloadAction<ImageDTO[]>) => {
state.imagesToChange = action.payload;
},
changeBoardReset: (state) => {
state.imagesToChange = [];
state.isModalOpen = false;
},
},
});
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } =
changeBoardModal.actions;
export default changeBoardModal.reducer;

View File

@ -0,0 +1,6 @@
import { ImageDTO } from 'services/api/types';
export type ChangeBoardModalState = {
isModalOpen: boolean;
imagesToChange: ImageDTO[];
};

View File

@ -3,6 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
ControlNetConfig,
controlNetDuplicated,
controlNetRemoved,
controlNetToggled,
@ -27,18 +28,27 @@ import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcesso
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
type ControlNetProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const ControlNet = (props: ControlNetProps) => {
const { controlNetId } = props;
const { controlNet } = props;
const { controlNetId } = controlNet;
const dispatch = useAppDispatch();
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
const cn = controlNet.controlNets[controlNetId];
if (!cn) {
return {
isEnabled: false,
shouldAutoConfig: false,
};
}
const { isEnabled, shouldAutoConfig } = cn;
return { isEnabled, shouldAutoConfig };
},
@ -96,7 +106,7 @@ const ControlNet = (props: ControlNetProps) => {
transitionDuration: '0.1s',
}}
>
<ParamControlNetModel controlNetId={controlNetId} />
<ParamControlNetModel controlNet={controlNet} />
</Box>
<IAIIconButton
size="sm"
@ -171,8 +181,8 @@ const ControlNet = (props: ControlNetProps) => {
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight controlNetId={controlNetId} />
<ParamControlNetBeginEnd controlNetId={controlNetId} />
<ParamControlNetWeight controlNet={controlNet} />
<ParamControlNetBeginEnd controlNet={controlNet} />
</Flex>
{!isExpanded && (
<Flex
@ -184,22 +194,22 @@ const ControlNet = (props: ControlNetProps) => {
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
<ControlNetImagePreview controlNet={controlNet} height={28} />
</Flex>
)}
</Flex>
<Flex sx={{ gap: 2 }}>
<ParamControlNetControlMode controlNetId={controlNetId} />
<ParamControlNetResizeMode controlNetId={controlNetId} />
<ParamControlNetControlMode controlNet={controlNet} />
<ParamControlNetResizeMode controlNet={controlNet} />
</Flex>
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
<ParamControlNetProcessorSelect controlNet={controlNet} />
</Flex>
{isExpanded && (
<>
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
<ControlNetProcessorComponent controlNetId={controlNetId} />
<ControlNetImagePreview controlNet={controlNet} height="392px" />
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
<ControlNetProcessorComponent controlNet={controlNet} />
</>
)}
</Flex>

View File

@ -12,50 +12,41 @@ import IAIDndImage from 'common/components/IAIDndImage';
import { memo, useCallback, useMemo, useState } from 'react';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import { controlNetImageChanged } from '../store/controlNetSlice';
import {
ControlNetConfig,
controlNetImageChanged,
} from '../store/controlNetSlice';
type Props = {
controlNetId: string;
controlNet: ControlNetConfig;
height: SystemStyleObject['h'];
};
const ControlNetImagePreview = (props: Props) => {
const { height, controlNetId } = props;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const { pendingControlImages } = controlNet;
const {
controlImage,
processedControlImage,
processorType,
isEnabled,
} = controlNet.controlNets[controlNetId];
return {
controlImageName: controlImage,
processedControlImageName: processedControlImage,
processorType,
isEnabled,
pendingControlImages,
};
},
defaultSelectorOptions
),
[controlNetId]
);
const ControlNetImagePreview = (props: Props) => {
const { height } = props;
const {
controlImageName,
processedControlImageName,
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
pendingControlImages,
isEnabled,
} = useAppSelector(selector);
controlNetId,
} = props.controlNet;
const dispatch = useAppDispatch();
const { pendingControlImages } = useAppSelector(selector);
const [isMouseOverImage, setIsMouseOverImage] = useState(false);

View File

@ -1,8 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import CannyProcessor from './processors/CannyProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
import HedProcessor from './processors/HedProcessor';
@ -17,28 +14,11 @@ import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
export type ControlNetProcessorProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId } = props;
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, processorNode } = useAppSelector(selector);
const { controlNetId, isEnabled, processorNode } = props.controlNet;
if (processorNode.type === 'canny_image_processor') {
return (

View File

@ -1,34 +1,19 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
import {
ControlNetConfig,
controlNetAutoConfigToggled,
} from 'features/controlNet/store/controlNetSlice';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
type Props = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const ParamControlNetShouldAutoConfig = (props: Props) => {
const { controlNetId } = props;
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, shouldAutoConfig } =
controlNet.controlNets[controlNetId];
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
),
[controlNetId]
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const isBusy = useAppSelector(selectIsBusy);
const handleShouldAutoConfigChanged = useCallback(() => {

View File

@ -9,48 +9,39 @@ import {
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useAppDispatch } from 'app/store/storeHooks';
import {
ControlNetConfig,
controlNetBeginStepPctChanged,
controlNetEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
type Props = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamControlNetBeginEnd = (props: Props) => {
const { controlNetId } = props;
const { beginStepPct, endStepPct, isEnabled, controlNetId } =
props.controlNet;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { beginStepPct, endStepPct, isEnabled } =
controlNet.controlNets[controlNetId];
return { beginStepPct, endStepPct, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(
controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] })
controlNetBeginStepPctChanged({
controlNetId,
beginStepPct: v[0] as number,
})
);
dispatch(
controlNetEndStepPctChanged({
controlNetId,
endStepPct: v[1] as number,
})
);
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] }));
},
[controlNetId, dispatch]
);

View File

@ -1,16 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
ControlNetConfig,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback, useMemo } from 'react';
import { useCallback } from 'react';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const CONTROL_MODE_DATA = [
@ -23,23 +21,8 @@ const CONTROL_MODE_DATA = [
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId } = props;
const { controlMode, isEnabled, controlNetId } = props.controlNet;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { controlMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { controlMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlMode, isEnabled } = useAppSelector(selector);
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {

View File

@ -5,7 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import {
ControlNetConfig,
controlNetModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
import { selectIsBusy } from 'features/system/store/systemSelectors';
@ -14,30 +17,24 @@ import { memo, useCallback, useMemo } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
type ParamControlNetModelProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const selector = createSelector(
stateSelector,
({ generation }) => {
const { model } = generation;
return { mainModel: model };
},
defaultSelectorOptions
);
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId } = props;
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
const dispatch = useAppDispatch();
const isBusy = useAppSelector(selectIsBusy);
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ generation, controlNet }) => {
const { model } = generation;
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
return { mainModel: model, controlNetModel, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
const { mainModel } = useAppSelector(selector);
const { data: controlNetModels } = useGetControlNetModelsQuery();

View File

@ -1,7 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
@ -9,13 +8,16 @@ import IAIMantineSearchableSelect, {
import { configSelector } from 'features/system/store/configSelectors';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import {
ControlNetConfig,
controlNetProcessorTypeChanged,
} from '../../store/controlNetSlice';
import { ControlNetProcessorType } from '../../store/types';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const selector = createSelector(
@ -52,23 +54,9 @@ const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
) => {
const dispatch = useAppDispatch();
const { controlNetId } = props;
const processorNodeSelector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { isEnabled, processorNode } =
controlNet.controlNets[controlNetId];
return { isEnabled, processorNode };
},
defaultSelectorOptions
),
[controlNetId]
);
const { controlNetId, isEnabled, processorNode } = props.controlNet;
const isBusy = useAppSelector(selectIsBusy);
const controlNetProcessors = useAppSelector(selector);
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
const handleProcessorTypeChanged = useCallback(
(v: string | null) => {

View File

@ -1,16 +1,14 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlNetConfig,
ResizeModes,
controlNetResizeModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback, useMemo } from 'react';
import { useCallback } from 'react';
type ParamControlNetResizeModeProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const RESIZE_MODE_DATA = [
@ -22,23 +20,8 @@ const RESIZE_MODE_DATA = [
export default function ParamControlNetResizeMode(
props: ParamControlNetResizeModeProps
) {
const { controlNetId } = props;
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { resizeMode, isEnabled } =
controlNet.controlNets[controlNetId];
return { resizeMode, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { resizeMode, isEnabled } = useAppSelector(selector);
const handleResizeModeChange = useCallback(
(resizeMode: ResizeModes) => {

View File

@ -1,32 +1,18 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback, useMemo } from 'react';
import {
ControlNetConfig,
controlNetWeightChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
type ParamControlNetWeightProps = {
controlNetId: string;
controlNet: ControlNetConfig;
};
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
const { controlNetId } = props;
const { weight, isEnabled, controlNetId } = props.controlNet;
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlNet }) => {
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
return { weight, isEnabled };
},
defaultSelectorOptions
),
[controlNetId]
);
const { weight, isEnabled } = useAppSelector(selector);
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(controlNetWeightChanged({ controlNetId, weight }));

View File

@ -4,7 +4,7 @@ import {
} from './types';
type ControlNetProcessorsDict = Record<
string,
ControlNetProcessorType,
{
type: ControlNetProcessorType | 'none';
label: string;

View File

@ -96,8 +96,11 @@ export const controlNetSlice = createSlice({
}>
) => {
const { sourceControlNetId, newControlNetId } = action.payload;
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
const oldControlNet = state.controlNets[sourceControlNetId];
if (!oldControlNet) {
return;
}
const newControlnet = cloneDeep(oldControlNet);
newControlnet.controlNetId = newControlNetId;
state.controlNets[newControlNetId] = newControlnet;
},
@ -124,8 +127,11 @@ export const controlNetSlice = createSlice({
action: PayloadAction<{ controlNetId: string }>
) => {
const { controlNetId } = action.payload;
state.controlNets[controlNetId].isEnabled =
!state.controlNets[controlNetId].isEnabled;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.isEnabled = !cn.isEnabled;
},
controlNetImageChanged: (
state,
@ -135,12 +141,14 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, controlImage } = action.payload;
state.controlNets[controlNetId].controlImage = controlImage;
state.controlNets[controlNetId].processedControlImage = null;
if (
controlImage !== null &&
state.controlNets[controlNetId].processorType !== 'none'
) {
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlImage = controlImage;
cn.processedControlImage = null;
if (controlImage !== null && cn.processorType !== 'none') {
state.pendingControlImages.push(controlNetId);
}
},
@ -152,8 +160,12 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, processedControlImage } = action.payload;
state.controlNets[controlNetId].processedControlImage =
processedControlImage;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = processedControlImage;
state.pendingControlImages = state.pendingControlImages.filter(
(id) => id !== controlNetId
);
@ -166,10 +178,15 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
if (state.controlNets[controlNetId].shouldAutoConfig) {
cn.model = model;
cn.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
@ -180,14 +197,13 @@ export const controlNetSlice = createSlice({
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
.none.default as RequiredControlNetProcessorNode;
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
},
@ -196,28 +212,48 @@ export const controlNetSlice = createSlice({
action: PayloadAction<{ controlNetId: string; weight: number }>
) => {
const { controlNetId, weight } = action.payload;
state.controlNets[controlNetId].weight = weight;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.weight = weight;
},
controlNetBeginStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
) => {
const { controlNetId, beginStepPct } = action.payload;
state.controlNets[controlNetId].beginStepPct = beginStepPct;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.beginStepPct = beginStepPct;
},
controlNetEndStepPctChanged: (
state,
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
) => {
const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.endStepPct = endStepPct;
},
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.controlMode = controlMode;
},
controlNetResizeModeChanged: (
state,
@ -227,7 +263,12 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, resizeMode } = action.payload;
state.controlNets[controlNetId].resizeMode = resizeMode;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.resizeMode = resizeMode;
},
controlNetProcessorParamsChanged: (
state,
@ -240,12 +281,17 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, changes } = action.payload;
const processorNode = state.controlNets[controlNetId].processorNode;
state.controlNets[controlNetId].processorNode = {
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const processorNode = cn.processorNode;
cn.processorNode = {
...processorNode,
...changes,
};
state.controlNets[controlNetId].shouldAutoConfig = false;
cn.shouldAutoConfig = false;
},
controlNetProcessorTypeChanged: (
state,
@ -255,12 +301,16 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId, processorType } = action.payload;
state.controlNets[controlNetId].processedControlImage = null;
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
state.controlNets[controlNetId].shouldAutoConfig = false;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
cn.processedControlImage = null;
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
cn.shouldAutoConfig = false;
},
controlNetAutoConfigToggled: (
state,
@ -269,37 +319,36 @@ export const controlNetSlice = createSlice({
}>
) => {
const { controlNetId } = action.payload;
const newShouldAutoConfig =
!state.controlNets[controlNetId].shouldAutoConfig;
const cn = state.controlNets[controlNetId];
if (!cn) {
return;
}
const newShouldAutoConfig = !cn.shouldAutoConfig;
if (newShouldAutoConfig) {
// manage the processor for the user
let processorType: ControlNetProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
if (
state.controlNets[controlNetId].model?.model_name.includes(
modelSubstring
)
) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
processorType
].default as RequiredControlNetProcessorNode;
cn.processorType = processorType;
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlNetProcessorNode;
} else {
state.controlNets[controlNetId].processorType = 'none';
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
.none.default as RequiredControlNetProcessorNode;
cn.processorType = 'none';
cn.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlNetProcessorNode;
}
}
state.controlNets[controlNetId].shouldAutoConfig = newShouldAutoConfig;
cn.shouldAutoConfig = newShouldAutoConfig;
},
controlNetReset: () => {
return { ...initialControlNetState };
@ -307,9 +356,11 @@ export const controlNetSlice = createSlice({
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
if (
state.controlNets[action.payload.controlNetId].controlImage !== null
) {
const cn = state.controlNets[action.payload.controlNetId];
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages.push(action.payload.controlNetId);
}
});

View File

@ -15,30 +15,42 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { stateSelector } from 'app/store/store';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { imageDeletionConfirmed } from '../store/actions';
import { selectImageUsage } from '../store/imageDeletionSelectors';
import {
imageToDeleteCleared,
isModalOpenChanged,
} from '../store/imageDeletionSlice';
import { getImageUsage, selectImageUsage } from '../store/selectors';
import { imageDeletionCanceled, isModalOpenChanged } from '../store/slice';
import ImageUsageMessage from './ImageUsageMessage';
import { ImageUsage } from '../store/types';
const selector = createSelector(
[stateSelector, selectImageUsage],
({ system, config, imageDeletion }, imageUsage) => {
(state, imagesUsage) => {
const { system, config, deleteImageModal } = state;
const { shouldConfirmOnDelete } = system;
const { canRestoreDeletedImagesFromBin } = config;
const { imageToDelete, isModalOpen } = imageDeletion;
const { imagesToDelete, isModalOpen } = deleteImageModal;
const allImageUsage = (imagesToDelete ?? []).map(({ image_name }) =>
getImageUsage(state, image_name)
);
const imageUsageSummary: ImageUsage = {
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
};
return {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
imageToDelete,
imageUsage,
imagesToDelete,
imagesUsage,
isModalOpen,
imageUsageSummary,
};
},
defaultSelectorOptions
@ -51,9 +63,10 @@ const DeleteImageModal = () => {
const {
shouldConfirmOnDelete,
canRestoreDeletedImagesFromBin,
imageToDelete,
imageUsage,
imagesToDelete,
imagesUsage,
isModalOpen,
imageUsageSummary,
} = useAppSelector(selector);
const handleChangeShouldConfirmOnDelete = useCallback(
@ -63,17 +76,19 @@ const DeleteImageModal = () => {
);
const handleClose = useCallback(() => {
dispatch(imageToDeleteCleared());
dispatch(imageDeletionCanceled());
dispatch(isModalOpenChanged(false));
}, [dispatch]);
const handleDelete = useCallback(() => {
if (!imageToDelete || !imageUsage) {
if (!imagesToDelete.length || !imagesUsage.length) {
return;
}
dispatch(imageToDeleteCleared());
dispatch(imageDeletionConfirmed({ imageDTO: imageToDelete, imageUsage }));
}, [dispatch, imageToDelete, imageUsage]);
dispatch(imageDeletionCanceled());
dispatch(
imageDeletionConfirmed({ imageDTOs: imagesToDelete, imagesUsage })
);
}, [dispatch, imagesToDelete, imagesUsage]);
const cancelRef = useRef<HTMLButtonElement>(null);
@ -92,7 +107,7 @@ const DeleteImageModal = () => {
<AlertDialogBody>
<Flex direction="column" gap={3}>
<ImageUsageMessage imageUsage={imageUsage} />
<ImageUsageMessage imageUsage={imageUsageSummary} />
<Divider />
<Text>
{canRestoreDeletedImagesFromBin

View File

@ -3,6 +3,6 @@ import { ImageDTO } from 'services/api/types';
import { ImageUsage } from './types';
export const imageDeletionConfirmed = createAction<{
imageDTO: ImageDTO;
imageUsage: ImageUsage;
}>('imageDeletion/imageDeletionConfirmed');
imageDTOs: ImageDTO[];
imagesUsage: ImageUsage[];
}>('deleteImageModal/imageDeletionConfirmed');

View File

@ -0,0 +1,6 @@
import { DeleteImageState } from './types';
export const initialDeleteImageState: DeleteImageState = {
imagesToDelete: [],
isModalOpen: false,
};

View File

@ -39,17 +39,17 @@ export const getImageUsage = (state: RootState, image_name: string) => {
export const selectImageUsage = createSelector(
[(state: RootState) => state],
(state) => {
const { imageToDelete } = state.imageDeletion;
const { imagesToDelete } = state.deleteImageModal;
if (!imageToDelete) {
return;
if (!imagesToDelete.length) {
return [];
}
const { image_name } = imageToDelete;
const imagesUsage = imagesToDelete.map((i) =>
getImageUsage(state, i.image_name)
);
const imageUsage = getImageUsage(state, image_name);
return imageUsage;
return imagesUsage;
},
defaultSelectorOptions
);

View File

@ -0,0 +1,28 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { initialDeleteImageState } from './initialState';
const deleteImageModal = createSlice({
name: 'deleteImageModal',
initialState: initialDeleteImageState,
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
},
imagesToDeleteSelected: (state, action: PayloadAction<ImageDTO[]>) => {
state.imagesToDelete = action.payload;
},
imageDeletionCanceled: (state) => {
state.imagesToDelete = [];
state.isModalOpen = false;
},
},
});
export const {
isModalOpenChanged,
imagesToDeleteSelected,
imageDeletionCanceled,
} = deleteImageModal.actions;
export default deleteImageModal.reducer;

View File

@ -0,0 +1,13 @@
import { ImageDTO } from 'services/api/types';
export type DeleteImageState = {
imagesToDelete: ImageDTO[];
isModalOpen: boolean;
};
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};

View File

@ -11,11 +11,14 @@ import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
const selector = createSelector(
[stateSelector],
({ gallery }) => {
const { autoAddBoardId } = gallery;
({ gallery, system }) => {
const { autoAddBoardId, autoAssignBoardOnClick } = gallery;
const { isProcessing } = system;
return {
autoAddBoardId,
autoAssignBoardOnClick,
isProcessing,
};
},
defaultSelectorOptions
@ -23,7 +26,8 @@ const selector = createSelector(
const BoardAutoAddSelect = () => {
const dispatch = useAppDispatch();
const { autoAddBoardId } = useAppSelector(selector);
const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } =
useAppSelector(selector);
const inputRef = useRef<HTMLInputElement>(null);
const { boards, hasBoards } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => {
@ -52,7 +56,7 @@ const BoardAutoAddSelect = () => {
return;
}
dispatch(autoAddBoardIdChanged(v === 'none' ? undefined : v));
dispatch(autoAddBoardIdChanged(v));
},
[dispatch]
);
@ -67,7 +71,7 @@ const BoardAutoAddSelect = () => {
data={boards}
nothingFound="No matching Boards"
itemComponent={IAIMantineSelectItemWithTooltip}
disabled={!hasBoards}
disabled={!hasBoards || autoAssignBoardOnClick || isProcessing}
filter={(value, item: SelectItem) =>
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim())

View File

@ -11,10 +11,11 @@ import { BoardDTO } from 'services/api/types';
import { menuListMotionProps } from 'theme/components/menu';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
import { BoardId } from 'features/gallery/store/types';
type Props = {
board?: BoardDTO;
board_id?: string;
board_id: BoardId;
children: ContextMenuProps<HTMLDivElement>['children'];
setBoardToDelete?: (board?: BoardDTO) => void;
};
@ -25,14 +26,17 @@ const BoardContextMenu = memo(
const selector = useMemo(
() =>
createSelector(stateSelector, ({ gallery }) => {
createSelector(stateSelector, ({ gallery, system }) => {
const isAutoAdd = gallery.autoAddBoardId === board_id;
return { isAutoAdd };
const isProcessing = system.isProcessing;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
}),
[board_id]
);
const { isAutoAdd } = useAppSelector(selector);
const { isAutoAdd, isProcessing, autoAssignBoardOnClick } =
useAppSelector(selector);
const boardName = useBoardName(board_id);
const handleSetAutoAdd = useCallback(() => {
@ -59,7 +63,7 @@ const BoardContextMenu = memo(
<MenuGroup title={boardName}>
<MenuItem
icon={<FaPlus />}
isDisabled={isAutoAdd}
isDisabled={isAutoAdd || isProcessing || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
Auto-add to this Board

View File

@ -1,43 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { AddToBatchDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
import { FaLayerGroup } from 'react-icons/fa';
import { useDispatch } from 'react-redux';
import GenericBoard from './GenericBoard';
const selector = createSelector(stateSelector, (state) => {
return {
count: state.gallery.batchImageNames.length,
};
});
const BatchBoard = ({ isSelected }: { isSelected: boolean }) => {
const dispatch = useDispatch();
const { count } = useAppSelector(selector);
const handleBatchBoardClick = useCallback(() => {
dispatch(boardIdSelected('batch'));
}, [dispatch]);
const droppableData: AddToBatchDropData = {
id: 'batch-board',
actionType: 'ADD_TO_BATCH',
};
return (
<GenericBoard
board_id="batch"
droppableData={droppableData}
onClick={handleBatchBoardClick}
isSelected={isSelected}
icon={FaLayerGroup}
label="Batch"
badgeCount={count}
/>
);
};
export default BatchBoard;

View File

@ -15,10 +15,9 @@ import NoBoardBoard from './NoBoardBoard';
const selector = createSelector(
[stateSelector],
({ boards, gallery }) => {
const { searchText } = boards;
const { selectedBoardId } = gallery;
return { selectedBoardId, searchText };
({ gallery }) => {
const { selectedBoardId, boardSearchText } = gallery;
return { selectedBoardId, boardSearchText };
},
defaultSelectorOptions
);
@ -29,11 +28,11 @@ type Props = {
const BoardsList = (props: Props) => {
const { isOpen } = props;
const { selectedBoardId, searchText } = useAppSelector(selector);
const { selectedBoardId, boardSearchText } = useAppSelector(selector);
const { data: boards } = useListAllBoardsQuery();
const filteredBoards = searchText
const filteredBoards = boardSearchText
? boards?.filter((board) =>
board.board_name.toLowerCase().includes(searchText.toLowerCase())
board.board_name.toLowerCase().includes(boardSearchText.toLowerCase())
)
: boards;
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
@ -75,7 +74,7 @@ const BoardsList = (props: Props) => {
}}
>
<GridItem sx={{ p: 1.5 }}>
<NoBoardBoard isSelected={selectedBoardId === undefined} />
<NoBoardBoard isSelected={selectedBoardId === 'none'} />
</GridItem>
{filteredBoards &&
filteredBoards.map((board) => (

View File

@ -9,7 +9,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
import { boardSearchTextChanged } from 'features/gallery/store/gallerySlice';
import {
ChangeEvent,
KeyboardEvent,
@ -21,27 +21,27 @@ import {
const selector = createSelector(
[stateSelector],
({ boards }) => {
const { searchText } = boards;
return { searchText };
({ gallery }) => {
const { boardSearchText } = gallery;
return { boardSearchText };
},
defaultSelectorOptions
);
const BoardsSearch = () => {
const dispatch = useAppDispatch();
const { searchText } = useAppSelector(selector);
const { boardSearchText } = useAppSelector(selector);
const inputRef = useRef<HTMLInputElement>(null);
const handleBoardSearch = useCallback(
(searchTerm: string) => {
dispatch(setBoardSearchText(searchTerm));
dispatch(boardSearchTextChanged(searchTerm));
},
[dispatch]
);
const clearBoardSearch = useCallback(() => {
dispatch(setBoardSearchText(''));
dispatch(boardSearchTextChanged(''));
}, [dispatch]);
const handleKeydown = useCallback(
@ -74,11 +74,11 @@ const BoardsSearch = () => {
<Input
ref={inputRef}
placeholder="Search Boards..."
value={searchText}
value={boardSearchText}
onKeyDown={handleKeydown}
onChange={handleChange}
/>
{searchText && searchText.length && (
{boardSearchText && boardSearchText.length && (
<InputRightElement>
<IconButton
onClick={clearBoardSearch}

View File

@ -7,19 +7,27 @@ import {
Icon,
Image,
Text,
Tooltip,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import {
autoAddBoardIdChanged,
boardIdSelected,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { FaUser } from 'react-icons/fa';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
import AutoAddIcon from '../AutoAddIcon';
@ -38,18 +46,25 @@ const GalleryBoard = memo(
() =>
createSelector(
stateSelector,
({ gallery }) => {
({ gallery, system }) => {
const isSelectedForAutoAdd =
board.board_id === gallery.autoAddBoardId;
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
const isProcessing = system.isProcessing;
return { isSelectedForAutoAdd };
return {
isSelectedForAutoAdd,
autoAssignBoardOnClick,
isProcessing,
};
},
defaultSelectorOptions
),
[board.board_id]
);
const { isSelectedForAutoAdd } = useAppSelector(selector);
const { isSelectedForAutoAdd, autoAssignBoardOnClick, isProcessing } =
useAppSelector(selector);
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
@ -57,6 +72,18 @@ const GalleryBoard = memo(
const handleMouseOut = useCallback(() => {
setIsHovered(false);
}, []);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (!imagesTotal || !assetsTotal) {
return undefined;
}
return `${imagesTotal} image${
imagesTotal > 1 ? 's' : ''
}, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`;
}, [assetsTotal, imagesTotal]);
const { currentData: coverImage } = useGetImageDTOQuery(
board.cover_image_name ?? skipToken
);
@ -66,15 +93,18 @@ const GalleryBoard = memo(
const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(board_id));
}, [board_id, dispatch]);
if (autoAssignBoardOnClick && !isProcessing) {
dispatch(autoAddBoardIdChanged(board_id));
}
}, [board_id, autoAssignBoardOnClick, isProcessing, dispatch]);
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
useUpdateBoardMutation();
const droppableData: MoveBoardDropData = useMemo(
const droppableData: AddToBoardDropData = useMemo(
() => ({
id: board_id,
actionType: 'MOVE_BOARD',
actionType: 'ADD_TO_BOARD',
context: { boardId: board_id },
}),
[board_id]
@ -135,6 +165,7 @@ const GalleryBoard = memo(
setBoardToDelete={setBoardToDelete}
>
{(ref) => (
<Tooltip label={tooltip} openDelay={1000} hasArrow>
<Flex
ref={ref}
onClick={handleSelectBoard}
@ -265,6 +296,7 @@ const GalleryBoard = memo(
dropLabel={<Text fontSize="md">Move</Text>}
/>
</Flex>
</Tooltip>
)}
</BoardContextMenu>
</Flex>

View File

@ -1,50 +1,60 @@
import { Box, Flex, Image, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import InvokeAILogoImage from 'assets/images/logo.png';
import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
import {
boardIdSelected,
autoAddBoardIdChanged,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
import AutoAddIcon from '../AutoAddIcon';
import BoardContextMenu from '../BoardContextMenu';
interface Props {
isSelected: boolean;
}
const selector = createSelector(
stateSelector,
({ gallery }) => {
const { autoAddBoardId } = gallery;
return { autoAddBoardId };
({ gallery, system }) => {
const { autoAddBoardId, autoAssignBoardOnClick } = gallery;
const { isProcessing } = system;
return { autoAddBoardId, autoAssignBoardOnClick, isProcessing };
},
defaultSelectorOptions
);
const NoBoardBoard = memo(({ isSelected }: Props) => {
const dispatch = useAppDispatch();
const { autoAddBoardId } = useAppSelector(selector);
const boardName = useBoardName(undefined);
const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } =
useAppSelector(selector);
const boardName = useBoardName('none');
const handleSelectBoard = useCallback(() => {
dispatch(boardIdSelected(undefined));
}, [dispatch]);
dispatch(boardIdSelected('none'));
if (autoAssignBoardOnClick && !isProcessing) {
dispatch(autoAddBoardIdChanged('none'));
}
}, [dispatch, autoAssignBoardOnClick, isProcessing]);
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
const handleMouseOut = useCallback(() => {
setIsHovered(false);
}, []);
const droppableData: MoveBoardDropData = useMemo(
const droppableData: RemoveFromBoardDropData = useMemo(
() => ({
id: 'no_board',
actionType: 'MOVE_BOARD',
context: { boardId: undefined },
actionType: 'REMOVE_FROM_BOARD',
}),
[]
);
@ -64,7 +74,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
h: 'full',
}}
>
<BoardContextMenu>
<BoardContextMenu board_id="none">
{(ref) => (
<Flex
ref={ref}
@ -91,17 +101,6 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
alignItems: 'center',
}}
>
{/* <Icon
boxSize={12}
as={FaBucket}
sx={{
opacity: 0.7,
color: 'base.500',
_dark: {
color: 'base.500',
},
}}
/> */}
<Image
src={InvokeAILogoImage}
alt="invoke-ai-logo"
@ -117,19 +116,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
}}
/>
</Flex>
{/* <Flex
sx={{
position: 'absolute',
insetInlineEnd: 0,
top: 0,
p: 1,
}}
>
<Badge variant="solid" sx={BASE_BADGE_STYLES}>
{totalImages}/{totalAssets}
</Badge>
</Flex> */}
{!autoAddBoardId && <AutoAddIcon />}
{autoAddBoardId === 'none' && <AutoAddIcon />}
<Flex
sx={{
position: 'absolute',

View File

@ -11,20 +11,20 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { ImageUsage } from 'app/contexts/AddImageToBoardContext';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import ImageUsageMessage from 'features/imageDeletion/components/ImageUsageMessage';
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import ImageUsageMessage from 'features/deleteImageModal/components/ImageUsageMessage';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { ImageUsage } from 'features/deleteImageModal/store/types';
import { some } from 'lodash-es';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllImageNamesForBoardQuery } from 'services/api/endpoints/boards';
import {
useDeleteBoardAndImagesMutation,
useDeleteBoardMutation,
useListAllImageNamesForBoardQuery,
} from 'services/api/endpoints/boards';
} from 'services/api/endpoints/images';
import { BoardDTO } from 'services/api/types';
type Props = {
@ -32,7 +32,7 @@ type Props = {
setBoardToDelete: (board?: BoardDTO) => void;
};
const DeleteImageModal = (props: Props) => {
const DeleteBoardModal = (props: Props) => {
const { boardToDelete, setBoardToDelete } = props;
const { t } = useTranslation();
const canRestoreDeletedImagesFromBin = useAppSelector(
@ -49,13 +49,10 @@ const DeleteImageModal = (props: Props) => {
);
const imageUsageSummary: ImageUsage = {
isInitialImage: some(allImageUsage, (usage) => usage.isInitialImage),
isCanvasImage: some(allImageUsage, (usage) => usage.isCanvasImage),
isNodesImage: some(allImageUsage, (usage) => usage.isNodesImage),
isControlNetImage: some(
allImageUsage,
(usage) => usage.isControlNetImage
),
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
};
return { imageUsageSummary };
}),
@ -176,4 +173,4 @@ const DeleteImageModal = (props: Props) => {
);
};
export default memo(DeleteImageModal);
export default memo(DeleteBoardModal);

View File

@ -1,93 +0,0 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Box,
Flex,
Spinner,
Text,
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useContext, useRef, useState } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll);
const { data: boards, isFetching } = useListAllBoardsQuery();
const { isOpen, onClose, handleAddToBoard, image } = useContext(
AddImageToBoardContext
);
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const cancelRef = useRef<HTMLButtonElement>(null);
const currentBoard = boards?.find(
(board) => board.board_id === image?.board_id
);
return (
<AlertDialog
isOpen={isOpen}
leastDestructiveRef={cancelRef}
onClose={onClose}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
</AlertDialogHeader>
<AlertDialogBody>
<Box>
<Flex direction="column" gap={3}>
{currentBoard && (
<Text>
Moving this image from{' '}
<strong>{currentBoard.board_name}</strong> to
</Text>
)}
{isFetching ? (
<Spinner />
) : (
<IAIMantineSearchableSelect
placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={(boards ?? []).map((board) => ({
label: board.board_name,
value: board.board_id,
}))}
/>
)}
</Flex>
</Box>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton onClick={onClose}>Cancel</IAIButton>
<IAIButton
isDisabled={!selectedBoard}
colorScheme="accent"
onClick={() => {
if (selectedBoard) {
handleAddToBoard(selectedBoard);
}
}}
ml={3}
>
{currentBoard ? 'Move' : 'Add'}
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(UpdateImageBoardModal);

Some files were not shown because too many files have changed in this diff Show More