mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into bugfix/fp16-models
This commit is contained in:
commit
9bacd77a79
14
.github/workflows/style-checks.yml
vendored
14
.github/workflows/style-checks.yml
vendored
@ -1,13 +1,14 @@
|
|||||||
name: Black # TODO: add isort and flake8 later
|
name: style checks
|
||||||
|
# just formatting for now
|
||||||
|
# TODO: add isort and flake8 later
|
||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request: {}
|
pull_request:
|
||||||
push:
|
push:
|
||||||
branches: master
|
branches: main
|
||||||
tags: "*"
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
black:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
@ -19,8 +20,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies with pip
|
- name: Install dependencies with pip
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade pip wheel
|
pip install black
|
||||||
pip install .[test]
|
|
||||||
|
|
||||||
# - run: isort --check-only .
|
# - run: isort --check-only .
|
||||||
- run: black --check .
|
- run: black --check .
|
||||||
|
50
.github/workflows/test-invoke-pip-skip.yml
vendored
50
.github/workflows/test-invoke-pip-skip.yml
vendored
@ -1,50 +0,0 @@
|
|||||||
name: Test invoke.py pip
|
|
||||||
|
|
||||||
# This is a dummy stand-in for the actual tests
|
|
||||||
# we don't need to run python tests on non-Python changes
|
|
||||||
# But PRs require passing tests to be mergeable
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
paths:
|
|
||||||
- '**'
|
|
||||||
- '!pyproject.toml'
|
|
||||||
- '!invokeai/**'
|
|
||||||
- '!tests/**'
|
|
||||||
- 'invokeai/frontend/web/**'
|
|
||||||
merge_group:
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
matrix:
|
|
||||||
if: github.event.pull_request.draft == false
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
python-version:
|
|
||||||
- '3.10'
|
|
||||||
pytorch:
|
|
||||||
- linux-cuda-11_7
|
|
||||||
- linux-rocm-5_2
|
|
||||||
- linux-cpu
|
|
||||||
- macos-default
|
|
||||||
- windows-cpu
|
|
||||||
include:
|
|
||||||
- pytorch: linux-cuda-11_7
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: linux-rocm-5_2
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: linux-cpu
|
|
||||||
os: ubuntu-22.04
|
|
||||||
- pytorch: macos-default
|
|
||||||
os: macOS-12
|
|
||||||
- pytorch: windows-cpu
|
|
||||||
os: windows-2022
|
|
||||||
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
|
|
||||||
runs-on: ${{ matrix.os }}
|
|
||||||
steps:
|
|
||||||
- name: skip
|
|
||||||
run: echo "no build required"
|
|
24
.github/workflows/test-invoke-pip.yml
vendored
24
.github/workflows/test-invoke-pip.yml
vendored
@ -3,16 +3,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
paths:
|
|
||||||
- 'pyproject.toml'
|
|
||||||
- 'invokeai/**'
|
|
||||||
- '!invokeai/frontend/web/**'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
|
||||||
- 'pyproject.toml'
|
|
||||||
- 'invokeai/**'
|
|
||||||
- 'tests/**'
|
|
||||||
- '!invokeai/frontend/web/**'
|
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
- 'opened'
|
- 'opened'
|
||||||
@ -65,10 +56,23 @@ jobs:
|
|||||||
id: checkout-sources
|
id: checkout-sources
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Check for changed python files
|
||||||
|
id: changed-files
|
||||||
|
uses: tj-actions/changed-files@v37
|
||||||
|
with:
|
||||||
|
files_yaml: |
|
||||||
|
python:
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'invokeai/**'
|
||||||
|
- '!invokeai/frontend/web/**'
|
||||||
|
- 'tests/**'
|
||||||
|
|
||||||
- name: set test prompt to main branch validation
|
- name: set test prompt to main branch validation
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
|
||||||
|
|
||||||
- name: setup python
|
- name: setup python
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
@ -76,6 +80,7 @@ jobs:
|
|||||||
cache-dependency-path: pyproject.toml
|
cache-dependency-path: pyproject.toml
|
||||||
|
|
||||||
- name: install invokeai
|
- name: install invokeai
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
env:
|
env:
|
||||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||||
run: >
|
run: >
|
||||||
@ -83,6 +88,7 @@ jobs:
|
|||||||
--editable=".[test]"
|
--editable=".[test]"
|
||||||
|
|
||||||
- name: run pytest
|
- name: run pytest
|
||||||
|
if: steps.changed-files.outputs.python_any_changed == 'true'
|
||||||
id: run-pytest
|
id: run-pytest
|
||||||
run: pytest
|
run: pytest
|
||||||
|
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
import os
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
SqliteBoardImageRecordStorage,
|
SqliteBoardImageRecordStorage,
|
||||||
)
|
)
|
||||||
@ -30,6 +29,7 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
|
from ..services.invocation_stats import InvocationStatsService
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -128,6 +128,7 @@ class ApiDependencies:
|
|||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,24 +1,30 @@
|
|||||||
from fastapi import Body, HTTPException, Path, Query
|
from fastapi import Body, HTTPException
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
|
from pydantic import BaseModel, Field
|
||||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
|
||||||
from invokeai.app.services.models.board_record import BoardDTO
|
|
||||||
from invokeai.app.services.models.image_record import ImageDTO
|
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
|
||||||
|
|
||||||
|
|
||||||
|
class AddImagesToBoardResult(BaseModel):
|
||||||
|
board_id: str = Field(description="The id of the board the images were added to")
|
||||||
|
added_image_names: list[str] = Field(description="The image names that were added to the board")
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveImagesFromBoardResult(BaseModel):
|
||||||
|
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
|
||||||
|
|
||||||
|
|
||||||
@board_images_router.post(
|
@board_images_router.post(
|
||||||
"/",
|
"/",
|
||||||
operation_id="create_board_image",
|
operation_id="add_image_to_board",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was added to a board successfully"},
|
201: {"description": "The image was added to a board successfully"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def create_board_image(
|
async def add_image_to_board(
|
||||||
board_id: str = Body(description="The id of the board to add to"),
|
board_id: str = Body(description="The id of the board to add to"),
|
||||||
image_name: str = Body(description="The name of the image to add"),
|
image_name: str = Body(description="The name of the image to add"),
|
||||||
):
|
):
|
||||||
@ -29,26 +35,78 @@ async def create_board_image(
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to add to board")
|
raise HTTPException(status_code=500, detail="Failed to add image to board")
|
||||||
|
|
||||||
|
|
||||||
@board_images_router.delete(
|
@board_images_router.delete(
|
||||||
"/",
|
"/",
|
||||||
operation_id="remove_board_image",
|
operation_id="remove_image_from_board",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was removed from the board successfully"},
|
201: {"description": "The image was removed from the board successfully"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def remove_board_image(
|
async def remove_image_from_board(
|
||||||
board_id: str = Body(description="The id of the board"),
|
image_name: str = Body(description="The name of the image to remove", embed=True),
|
||||||
image_name: str = Body(description="The name of the image to remove"),
|
|
||||||
):
|
):
|
||||||
"""Deletes a board_image"""
|
"""Removes an image from its board, if it had one"""
|
||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
|
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
board_id=board_id, image_name=image_name
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
raise HTTPException(status_code=500, detail="Failed to remove image from board")
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.post(
|
||||||
|
"/batch",
|
||||||
|
operation_id="add_images_to_board",
|
||||||
|
responses={
|
||||||
|
201: {"description": "Images were added to board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=AddImagesToBoardResult,
|
||||||
|
)
|
||||||
|
async def add_images_to_board(
|
||||||
|
board_id: str = Body(description="The id of the board to add to"),
|
||||||
|
image_names: list[str] = Body(description="The names of the images to add", embed=True),
|
||||||
|
) -> AddImagesToBoardResult:
|
||||||
|
"""Adds a list of images to a board"""
|
||||||
|
try:
|
||||||
|
added_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.board_images.add_image_to_board(
|
||||||
|
board_id=board_id, image_name=image_name
|
||||||
|
)
|
||||||
|
added_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to add images to board")
|
||||||
|
|
||||||
|
|
||||||
|
@board_images_router.post(
|
||||||
|
"/batch/delete",
|
||||||
|
operation_id="remove_images_from_board",
|
||||||
|
responses={
|
||||||
|
201: {"description": "Images were removed from board successfully"},
|
||||||
|
},
|
||||||
|
status_code=201,
|
||||||
|
response_model=RemoveImagesFromBoardResult,
|
||||||
|
)
|
||||||
|
async def remove_images_from_board(
|
||||||
|
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
|
||||||
|
) -> RemoveImagesFromBoardResult:
|
||||||
|
"""Removes a list of images from their board, if they had one"""
|
||||||
|
try:
|
||||||
|
removed_image_names: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
|
||||||
|
removed_image_names.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to remove images from board")
|
||||||
|
@ -5,6 +5,7 @@ from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadF
|
|||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
@ -25,7 +26,7 @@ IMAGE_MAX_AGE = 31536000
|
|||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
@images_router.post(
|
||||||
"/",
|
"/upload",
|
||||||
operation_id="upload_image",
|
operation_id="upload_image",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "The image was uploaded successfully"},
|
201: {"description": "The image was uploaded successfully"},
|
||||||
@ -77,7 +78,7 @@ async def upload_image(
|
|||||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_name}", operation_id="delete_image")
|
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -103,7 +104,7 @@ async def clear_intermediates() -> int:
|
|||||||
|
|
||||||
|
|
||||||
@images_router.patch(
|
@images_router.patch(
|
||||||
"/{image_name}",
|
"/i/{image_name}",
|
||||||
operation_id="update_image",
|
operation_id="update_image",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
@ -120,7 +121,7 @@ async def update_image(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}",
|
"/i/{image_name}",
|
||||||
operation_id="get_image_dto",
|
operation_id="get_image_dto",
|
||||||
response_model=ImageDTO,
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
@ -136,7 +137,7 @@ async def get_image_dto(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/metadata",
|
"/i/{image_name}/metadata",
|
||||||
operation_id="get_image_metadata",
|
operation_id="get_image_metadata",
|
||||||
response_model=ImageMetadata,
|
response_model=ImageMetadata,
|
||||||
)
|
)
|
||||||
@ -152,7 +153,7 @@ async def get_image_metadata(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/full",
|
"/i/{image_name}/full",
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -187,7 +188,7 @@ async def get_image_full(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/thumbnail",
|
"/i/{image_name}/thumbnail",
|
||||||
operation_id="get_image_thumbnail",
|
operation_id="get_image_thumbnail",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@ -216,7 +217,7 @@ async def get_image_thumbnail(
|
|||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/{image_name}/urls",
|
"/i/{image_name}/urls",
|
||||||
operation_id="get_image_urls",
|
operation_id="get_image_urls",
|
||||||
response_model=ImageUrlsDTO,
|
response_model=ImageUrlsDTO,
|
||||||
)
|
)
|
||||||
@ -265,3 +266,24 @@ async def list_image_dtos(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return image_dtos
|
return image_dtos
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteImagesFromListResult(BaseModel):
|
||||||
|
deleted_images: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
|
||||||
|
async def delete_images_from_list(
|
||||||
|
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
|
||||||
|
) -> DeleteImagesFromListResult:
|
||||||
|
try:
|
||||||
|
deleted_images: list[str] = []
|
||||||
|
for image_name in image_names:
|
||||||
|
try:
|
||||||
|
ApiDependencies.invoker.services.images.delete(image_name)
|
||||||
|
deleted_images.append(image_name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return DeleteImagesFromListResult(deleted_images=deleted_images)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete images")
|
||||||
|
@ -37,6 +37,7 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@ -311,6 +312,7 @@ def invoke_cli():
|
|||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
@ -109,12 +109,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
|
(
|
||||||
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@ -173,7 +176,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@ -197,12 +200,15 @@ class SDXLPromptInvocationBase:
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
|
(
|
||||||
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@ -210,8 +216,8 @@ class SDXLPromptInvocationBase:
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
with ModelPatcher.apply_lora(
|
||||||
text_encoder_info.context.model, _lora_loader()
|
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
@ -247,7 +253,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
return c, c_pooled, None
|
return c, c_pooled, None
|
||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(),
|
||||||
context=context,
|
context=context,
|
||||||
@ -271,12 +277,15 @@ class SDXLPromptInvocationBase:
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
|
(
|
||||||
|
name,
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
@ -284,8 +293,8 @@ class SDXLPromptInvocationBase:
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
with ModelPatcher.apply_lora(
|
||||||
text_encoder_info.context.model, _lora_loader()
|
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
@ -357,11 +366,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
|
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -415,7 +424,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -467,11 +477,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
|
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
|
||||||
if self.style.strip() == "":
|
if self.style.strip() == "":
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
|
||||||
else:
|
else:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
@ -525,7 +535,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
|
# TODO: if there will appear lora for refiner - write proper prefix
|
||||||
|
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
|
||||||
|
|
||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
|
|||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -10,16 +10,17 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModelExcludeNull):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
lora: LoRAModelField = Field(description="The LoRA model")
|
lora: LoRAModelField = Field(description="The LoRA model")
|
||||||
weight: float = Field(description="The weight of the LoRA model")
|
weight: float = Field(description="The weight of the LoRA model")
|
||||||
|
|
||||||
|
|
||||||
class CoreMetadata(BaseModel):
|
class CoreMetadata(BaseModelExcludeNull):
|
||||||
"""Core generation metadata for an image generated in InvokeAI."""
|
"""Core generation metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
generation_mode: str = Field(
|
generation_mode: str = Field(
|
||||||
@ -70,7 +71,7 @@ class CoreMetadata(BaseModel):
|
|||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
class ImageMetadata(BaseModelExcludeNull):
|
||||||
"""An image's generation metadata"""
|
"""An image's generation metadata"""
|
||||||
|
|
||||||
metadata: Optional[dict] = Field(
|
metadata: Optional[dict] = Field(
|
||||||
|
@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""Model loader output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
|
||||||
|
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
|
||||||
|
|
||||||
|
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
||||||
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
|
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
|
||||||
|
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
|
||||||
|
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Lora Loader",
|
||||||
|
"tags": ["lora", "loader"],
|
||||||
|
"type_hints": {"lora": "lora_model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
|
if self.lora is None:
|
||||||
|
raise Exception("No LoRA provided")
|
||||||
|
|
||||||
|
base_model = self.lora.base_model
|
||||||
|
lora_name = self.lora.model_name
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||||
|
|
||||||
|
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||||
|
|
||||||
|
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||||
|
|
||||||
|
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||||
|
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||||
|
|
||||||
|
output = SDXLLoraLoaderOutput()
|
||||||
|
|
||||||
|
if self.unet is not None:
|
||||||
|
output.unet = copy.deepcopy(self.unet)
|
||||||
|
output.unet.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.clip is not None:
|
||||||
|
output.clip = copy.deepcopy(self.clip)
|
||||||
|
output.clip.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.clip2 is not None:
|
||||||
|
output.clip2 = copy.deepcopy(self.clip2)
|
||||||
|
output.clip2.loras.append(
|
||||||
|
LoraInfo(
|
||||||
|
base_model=base_model,
|
||||||
|
model_name=lora_name,
|
||||||
|
model_type=ModelType.Lora,
|
||||||
|
submodel=None,
|
||||||
|
weight=self.weight,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class VAEModelField(BaseModel):
|
class VAEModelField(BaseModel):
|
||||||
"""Vae model field"""
|
"""Vae model field"""
|
||||||
|
|
||||||
|
@ -65,7 +65,6 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
@ -76,18 +75,14 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append(
|
ti_list.append(
|
||||||
# stack.enter_context(
|
(
|
||||||
# context.services.model_manager.get_model(
|
name,
|
||||||
# model_name=name,
|
|
||||||
# base_model=self.clip.text_encoder.base_model,
|
|
||||||
# model_type=ModelType.TextualInversion,
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
).context.model
|
).context.model,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# print(e)
|
# print(e)
|
||||||
|
@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import Field, validator
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
|
|
||||||
@ -293,10 +293,20 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
|
|
||||||
|
def _lora_loader():
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||||
timesteps = scheduler.timesteps
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
@ -543,9 +553,19 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _lora_loader():
|
||||||
|
for lora in self.unet.loras:
|
||||||
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
yield (lora_info.context.model, lora.weight)
|
||||||
|
del lora_info
|
||||||
|
return
|
||||||
|
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||||
# apply denoising_start
|
# apply denoising_start
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||||
|
@ -25,7 +25,6 @@ class BoardImageRecordStorageBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Removes an image from a board."""
|
"""Removes an image from a board."""
|
||||||
@ -154,7 +153,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
|
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
@ -162,9 +160,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM board_images
|
DELETE FROM board_images
|
||||||
WHERE board_id = ? AND image_name = ?;
|
WHERE image_name = ?;
|
||||||
""",
|
""",
|
||||||
(board_id, image_name),
|
(image_name,),
|
||||||
)
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
|
@ -31,7 +31,6 @@ class BoardImagesServiceABC(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Removes an image from a board."""
|
"""Removes an image from a board."""
|
||||||
@ -93,10 +92,9 @@ class BoardImagesService(BoardImagesServiceABC):
|
|||||||
|
|
||||||
def remove_image_from_board(
|
def remove_image_from_board(
|
||||||
self,
|
self,
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
image_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._services.board_image_records.remove_image_from_board(board_id, image_name)
|
self._services.board_image_records.remove_image_from_board(image_name)
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(
|
def get_all_board_image_names_for_board(
|
||||||
self,
|
self,
|
||||||
|
@ -289,9 +289,10 @@ class ImageService(ImageServiceABC):
|
|||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||||
try:
|
try:
|
||||||
image_record = self._services.image_records.get(image_name)
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
metadata = self._services.image_records.get_metadata(image_name)
|
||||||
|
|
||||||
if not image_record.session_id:
|
if not image_record.session_id:
|
||||||
return ImageMetadata()
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||||
graph = None
|
graph = None
|
||||||
@ -303,7 +304,6 @@ class ImageService(ImageServiceABC):
|
|||||||
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
metadata = self._services.image_records.get_metadata(image_name)
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
except ImageRecordNotFoundException:
|
except ImageRecordNotFoundException:
|
||||||
self._services.logger.error("Image record not found")
|
self._services.logger.error("Image record not found")
|
||||||
|
@ -32,6 +32,7 @@ class InvocationServices:
|
|||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
performance_statistics: "InvocationStatsServiceBase"
|
||||||
queue: "InvocationQueueABC"
|
queue: "InvocationQueueABC"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -47,6 +48,7 @@ class InvocationServices:
|
|||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
@ -61,4 +63,5 @@ class InvocationServices:
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
self.performance_statistics = performance_statistics
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
223
invokeai/app/services/invocation_stats.py
Normal file
223
invokeai/app/services/invocation_stats.py
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
|
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
statistics = InvocationStatsService(graph_execution_manager)
|
||||||
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
|
... execute graphs...
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
|
Typical output:
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||||
|
|
||||||
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||||
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
|
from .graph import GraphExecutionState
|
||||||
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsServiceBase(ABC):
|
||||||
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
"""
|
||||||
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
|
:param graph_execution_manager: Graph execution manager for this session
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> AbstractContextManager:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics on the execution
|
||||||
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
|
"""
|
||||||
|
Reset all statistics for the indicated graph
|
||||||
|
:param graph_execution_state_id
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_invocation_stats(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
invocation_type: str,
|
||||||
|
time_used: float,
|
||||||
|
vram_used: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Time used by node's exection (sec)
|
||||||
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||||
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||||
|
|
||||||
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
self._stats: Dict[str, NodeLog] = {}
|
||||||
|
|
||||||
|
class StatsContext:
|
||||||
|
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
|
||||||
|
self.invocation = invocation
|
||||||
|
self.collector = collector
|
||||||
|
self.graph_id = graph_id
|
||||||
|
self.start_time = 0
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start_time = time.time()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.collector.update_invocation_stats(
|
||||||
|
self.graph_id,
|
||||||
|
self.invocation.type,
|
||||||
|
time.time() - self.start_time,
|
||||||
|
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> StatsContext:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||||
|
"""
|
||||||
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
|
return self.StatsContext(invocation, graph_execution_state_id, self)
|
||||||
|
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
self._stats = {}
|
||||||
|
|
||||||
|
def reset_stats(self, graph_execution_id: str):
|
||||||
|
"""Zero the statistics for the indicated graph."""
|
||||||
|
try:
|
||||||
|
self._stats.pop(graph_execution_id)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
|
||||||
|
|
||||||
|
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Floating point seconds used by node's exection
|
||||||
|
"""
|
||||||
|
if not self._stats[graph_id].nodes.get(invocation_type):
|
||||||
|
self._stats[graph_id].nodes[invocation_type] = NodeStats()
|
||||||
|
stats = self._stats[graph_id].nodes[invocation_type]
|
||||||
|
stats.calls += 1
|
||||||
|
stats.time_used += time_used
|
||||||
|
stats.max_vram = max(stats.max_vram, vram_used)
|
||||||
|
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Send the statistics to the system logger at the info level.
|
||||||
|
Stats will only be printed if when the execution of the graph
|
||||||
|
is complete.
|
||||||
|
"""
|
||||||
|
completed = set()
|
||||||
|
for graph_id, node_log in self._stats.items():
|
||||||
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||||
|
if not current_graph_state.is_complete():
|
||||||
|
continue
|
||||||
|
|
||||||
|
total_time = 0
|
||||||
|
logger.info(f"Graph stats: {graph_id}")
|
||||||
|
logger.info("Node Calls Seconds VRAM Used")
|
||||||
|
for node_type, stats in self._stats[graph_id].nodes.items():
|
||||||
|
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
|
||||||
|
total_time += stats.time_used
|
||||||
|
|
||||||
|
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
|
||||||
|
|
||||||
|
completed.add(graph_id)
|
||||||
|
|
||||||
|
for graph_id in completed:
|
||||||
|
del self._stats[graph_id]
|
8
invokeai/app/services/models/board_image.py
Normal file
8
invokeai/app/services/models/board_image.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImage(BaseModelExcludeNull):
|
||||||
|
board_id: str = Field(description="The id of the board")
|
||||||
|
image_name: str = Field(description="The name of the image")
|
@ -1,10 +1,11 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import Field
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class BoardRecord(BaseModel):
|
class BoardRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized board record."""
|
"""Deserialized board record."""
|
||||||
|
|
||||||
board_id: str = Field(description="The unique ID of the board.")
|
board_id: str = Field(description="The unique ID of the board.")
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModel):
|
class ImageRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized image record without metadata."""
|
"""Deserialized image record without metadata."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
@ -40,7 +41,7 @@ class ImageRecord(BaseModel):
|
|||||||
"""The node ID that generated this image, if it is a generated image."""
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
||||||
"""A set of changes to apply to an image record.
|
"""A set of changes to apply to an image record.
|
||||||
|
|
||||||
Only limited changes are valid:
|
Only limited changes are valid:
|
||||||
@ -60,7 +61,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
|||||||
"""The image's new `is_intermediate` flag."""
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModel):
|
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||||
"""The URLs for an image and its thumbnail."""
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
@ -76,11 +77,15 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
|
|||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
def image_record_to_dto(
|
||||||
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
|
image_record: ImageRecord,
|
||||||
|
image_url: str,
|
||||||
|
thumbnail_url: str,
|
||||||
|
board_id: Optional[str],
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Converts an image record to an image DTO."""
|
"""Converts an image record to an image DTO."""
|
||||||
return ImageDTO(
|
return ImageDTO(
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from threading import Event, Thread, BoundedSemaphore
|
from threading import BoundedSemaphore, Event, Thread
|
||||||
|
|
||||||
from ..invocations.baseinvocation import InvocationContext
|
|
||||||
from .invocation_queue import InvocationQueueItem
|
|
||||||
from .invoker import InvocationProcessorABC, Invoker
|
|
||||||
from ..models.exceptions import CanceledException
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
|
from ..models.exceptions import CanceledException
|
||||||
|
from .invocation_queue import InvocationQueueItem
|
||||||
|
from .invocation_stats import InvocationStatsServiceBase
|
||||||
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
@ -35,6 +36,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
def __process(self, stop_event: Event):
|
def __process(self, stop_event: Event):
|
||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
|
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||||
@ -83,6 +86,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
outputs = invocation.invoke(
|
outputs = invocation.invoke(
|
||||||
InvocationContext(
|
InvocationContext(
|
||||||
services=self.__invoker.services,
|
services=self.__invoker.services,
|
||||||
@ -107,11 +111,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.dict(),
|
||||||
)
|
)
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -133,7 +139,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
|
@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
|
|
||||||
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
if thumbnail:
|
if thumbnail:
|
||||||
return f"{self._base_url}/images/{image_basename}/thumbnail"
|
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||||
|
|
||||||
return f"{self._base_url}/images/{image_basename}/full"
|
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||||
|
23
invokeai/app/util/model_exclude_null.py
Normal file
23
invokeai/app/util/model_exclude_null.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Any
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
We want to exclude null values from objects that make their way to the client.
|
||||||
|
|
||||||
|
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
|
||||||
|
dict method to do this.
|
||||||
|
|
||||||
|
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModelExcludeNull(BaseModel):
|
||||||
|
def dict(self, *args, **kwargs) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Override the default dict method to exclude None values in the response
|
||||||
|
"""
|
||||||
|
kwargs.pop("exclude_none", None)
|
||||||
|
return super().dict(*args, exclude_none=True, **kwargs)
|
||||||
|
|
||||||
|
pass
|
@ -305,7 +305,7 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if "model_index.json" in files and "unet/model.onnx" not in files:
|
if "model_index.json" in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
elif "unet/model.onnx" in files:
|
elif "unet/model.onnx" in files:
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
|
@ -13,3 +13,4 @@ from .models import (
|
|||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
)
|
)
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
from .lora import ModelPatcher
|
||||||
|
@ -20,424 +20,6 @@ from diffusers.models import UNet2DConditionModel
|
|||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
# TODO: rename and split this file
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
|
||||||
# rank: Optional[int]
|
|
||||||
# alpha: Optional[float]
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
# layer_key: str
|
|
||||||
|
|
||||||
# @property
|
|
||||||
# def scale(self):
|
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
if "alpha" in values:
|
|
||||||
self.alpha = values["alpha"].item()
|
|
||||||
else:
|
|
||||||
self.alpha = None
|
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
|
||||||
self.bias = torch.sparse_coo_tensor(
|
|
||||||
values["bias_indices"],
|
|
||||||
values["bias_values"],
|
|
||||||
tuple(values["bias_size"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
|
||||||
self.layer_key = layer_key
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
module: torch.nn.Module,
|
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
|
||||||
multiplier: float,
|
|
||||||
):
|
|
||||||
if type(module) == torch.nn.Conv2d:
|
|
||||||
op = torch.nn.functional.conv2d
|
|
||||||
extra_args = dict(
|
|
||||||
stride=module.stride,
|
|
||||||
padding=module.padding,
|
|
||||||
dilation=module.dilation,
|
|
||||||
groups=module.groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
op = torch.nn.functional.linear
|
|
||||||
extra_args = {}
|
|
||||||
|
|
||||||
weight = self.get_weight()
|
|
||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
return (
|
|
||||||
op(
|
|
||||||
*input_h,
|
|
||||||
(weight + bias).view(module.weight.shape),
|
|
||||||
None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
* multiplier
|
|
||||||
* scale
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for val in [self.bias]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
|
||||||
class LoRALayer(LoRALayerBase):
|
|
||||||
# up: torch.Tensor
|
|
||||||
# mid: Optional[torch.Tensor]
|
|
||||||
# down: torch.Tensor
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
|
||||||
self.down = values["lora_down.weight"]
|
|
||||||
if "lora_mid.weight" in values:
|
|
||||||
self.mid = values["lora_mid.weight"]
|
|
||||||
else:
|
|
||||||
self.mid = None
|
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.mid is not None:
|
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
|
||||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
|
||||||
else:
|
|
||||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.up, self.mid, self.down]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.mid is not None:
|
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
|
||||||
# w1_a: torch.Tensor
|
|
||||||
# w1_b: torch.Tensor
|
|
||||||
# w2_a: torch.Tensor
|
|
||||||
# w2_b: torch.Tensor
|
|
||||||
# t1: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.w1_a = values["hada_w1_a"]
|
|
||||||
self.w1_b = values["hada_w1_b"]
|
|
||||||
self.w2_a = values["hada_w2_a"]
|
|
||||||
self.w2_b = values["hada_w2_b"]
|
|
||||||
|
|
||||||
if "hada_t1" in values:
|
|
||||||
self.t1 = values["hada_t1"]
|
|
||||||
else:
|
|
||||||
self.t1 = None
|
|
||||||
|
|
||||||
if "hada_t2" in values:
|
|
||||||
self.t2 = values["hada_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
if self.t1 is None:
|
|
||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
|
||||||
|
|
||||||
else:
|
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
|
||||||
weight = rebuild1 * rebuild2
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t1 is not None:
|
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
|
||||||
# w1: Optional[torch.Tensor] = None
|
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
|
||||||
# w2: Optional[torch.Tensor] = None
|
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
|
||||||
# t2: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: dict,
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
if "lokr_w1" in values:
|
|
||||||
self.w1 = values["lokr_w1"]
|
|
||||||
self.w1_a = None
|
|
||||||
self.w1_b = None
|
|
||||||
else:
|
|
||||||
self.w1 = None
|
|
||||||
self.w1_a = values["lokr_w1_a"]
|
|
||||||
self.w1_b = values["lokr_w1_b"]
|
|
||||||
|
|
||||||
if "lokr_w2" in values:
|
|
||||||
self.w2 = values["lokr_w2"]
|
|
||||||
self.w2_a = None
|
|
||||||
self.w2_b = None
|
|
||||||
else:
|
|
||||||
self.w2 = None
|
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
|
|
||||||
if "lokr_t2" in values:
|
|
||||||
self.t2 = values["lokr_t2"]
|
|
||||||
else:
|
|
||||||
self.t2 = None
|
|
||||||
|
|
||||||
if "lokr_w1_b" in values:
|
|
||||||
self.rank = values["lokr_w1_b"].shape[0]
|
|
||||||
elif "lokr_w2_b" in values:
|
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
|
||||||
else:
|
|
||||||
self.rank = None # unscaled
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
w1 = self.w1
|
|
||||||
if w1 is None:
|
|
||||||
w1 = self.w1_a @ self.w1_b
|
|
||||||
|
|
||||||
w2 = self.w2
|
|
||||||
if w2 is None:
|
|
||||||
if self.t2 is None:
|
|
||||||
w2 = self.w2_a @ self.w2_b
|
|
||||||
else:
|
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
||||||
w2 = w2.contiguous()
|
|
||||||
weight = torch.kron(w1, w2)
|
|
||||||
|
|
||||||
return weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
|
||||||
if val is not None:
|
|
||||||
model_size += val.nelement() * val.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w1 is not None:
|
|
||||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.w2 is not None:
|
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
|
||||||
else:
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
if self.t2 is not None:
|
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: # (torch.nn.Module):
|
|
||||||
_name: str
|
|
||||||
layers: Dict[str, LoRALayer]
|
|
||||||
_device: torch.device
|
|
||||||
_dtype: torch.dtype
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
layers: Dict[str, LoRALayer],
|
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
|
||||||
self._name = name
|
|
||||||
self._device = device or torch.cpu
|
|
||||||
self._dtype = dtype or torch.float32
|
|
||||||
self.layers = layers
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return self._device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self._dtype
|
|
||||||
|
|
||||||
def to(
|
|
||||||
self,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> LoRAModel:
|
|
||||||
# TODO: try revert if exception?
|
|
||||||
for key, layer in self.layers.items():
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
self._device = device
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = 0
|
|
||||||
for _, layer in self.layers.items():
|
|
||||||
model_size += layer.calc_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_checkpoint(
|
|
||||||
cls,
|
|
||||||
file_path: Union[str, Path],
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
):
|
|
||||||
device = device or torch.device("cpu")
|
|
||||||
dtype = dtype or torch.float32
|
|
||||||
|
|
||||||
if isinstance(file_path, str):
|
|
||||||
file_path = Path(file_path)
|
|
||||||
|
|
||||||
model = cls(
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
name=file_path.stem, # TODO:
|
|
||||||
layers=dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
|
||||||
else:
|
|
||||||
state_dict = torch.load(file_path, map_location="cpu")
|
|
||||||
|
|
||||||
state_dict = cls._group_state(state_dict)
|
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
|
||||||
# lora and locon
|
|
||||||
if "lora_down.weight" in values:
|
|
||||||
layer = LoRALayer(layer_key, values)
|
|
||||||
|
|
||||||
# loha
|
|
||||||
elif "hada_w1_b" in values:
|
|
||||||
layer = LoHALayer(layer_key, values)
|
|
||||||
|
|
||||||
# lokr
|
|
||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
|
||||||
layer = LoKRLayer(layer_key, values)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# TODO: diff/ia3/... format
|
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
|
||||||
state_dict[layer_key].clear()
|
|
||||||
|
|
||||||
layer.to(device=device, dtype=dtype)
|
|
||||||
model.layers[layer_key] = layer
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _group_state(state_dict: dict):
|
|
||||||
state_dict_groupped = dict()
|
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
stem, leaf = key.split(".", 1)
|
|
||||||
if stem not in state_dict_groupped:
|
|
||||||
state_dict_groupped[stem] = dict()
|
|
||||||
state_dict_groupped[stem][leaf] = value
|
|
||||||
|
|
||||||
return state_dict_groupped
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
(lora_model1, 0.7),
|
(lora_model1, 0.7),
|
||||||
@ -516,6 +98,26 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@contextmanager
|
||||||
|
def apply_sdxl_lora_text_encoder2(
|
||||||
|
cls,
|
||||||
|
text_encoder: CLIPTextModel,
|
||||||
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
|
):
|
||||||
|
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
||||||
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora(
|
def apply_lora(
|
||||||
@ -562,7 +164,7 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
ti_list: List[Any],
|
ti_list: List[Tuple[str, Any]],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
init_tokens_count = None
|
init_tokens_count = None
|
||||||
new_tokens_added = None
|
new_tokens_added = None
|
||||||
@ -572,27 +174,27 @@ class ModelPatcher:
|
|||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
def _get_trigger(ti_name, index):
|
||||||
trigger = ti.name
|
trigger = ti_name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||||
model_embeddings = text_encoder.get_input_embeddings()
|
model_embeddings = text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i]
|
embedding = ti.embedding[i]
|
||||||
trigger = _get_trigger(ti, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
@ -637,7 +239,6 @@ class ModelPatcher:
|
|||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
name: str
|
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -651,7 +252,6 @@ class TextualInversionModel:
|
|||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
result.name = file_path.stem # TODO:
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
@ -828,7 +428,7 @@ class ONNXModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: IAIOnnxRuntimeModel,
|
text_encoder: IAIOnnxRuntimeModel,
|
||||||
ti_list: List[Any],
|
ti_list: List[Tuple[str, Any]],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
@ -841,17 +441,17 @@ class ONNXModelPatcher:
|
|||||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
def _get_trigger(ti_name, index):
|
||||||
trigger = ti.name
|
trigger = ti_name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||||
@ -861,11 +461,11 @@ class ONNXModelPatcher:
|
|||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
for ti in ti_list:
|
for ti_name, ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i].detach().numpy()
|
embedding = ti.embedding[i].detach().numpy()
|
||||||
trigger = _get_trigger(ti, i)
|
trigger = _get_trigger(ti_name, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
|
@ -28,8 +28,6 @@ import torch
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import get_invokeai_config
|
|
||||||
from .lora import LoRAModel, TextualInversionModel
|
|
||||||
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
||||||
|
|
||||||
# Maximum size of the cache, in gigs
|
# Maximum size of the cache, in gigs
|
||||||
@ -188,7 +186,7 @@ class ModelCache(object):
|
|||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
|
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
|
@ -472,7 +472,7 @@ class ModelManager(object):
|
|||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
override_path = getattr(model_config, submodel_type)
|
override_path = getattr(model_config, submodel_type)
|
||||||
if override_path:
|
if override_path:
|
||||||
model_path = self.app_config.root_path / override_path
|
model_path = self.resolve_path(override_path)
|
||||||
model_type = submodel_type
|
model_type = submodel_type
|
||||||
submodel_type = None
|
submodel_type = None
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@ -670,7 +670,7 @@ class ModelManager(object):
|
|||||||
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
# remove conversion cache as config changed
|
# remove conversion cache as config changed
|
||||||
old_model_path = self.app_config.root_path / old_model.path
|
old_model_path = self.resolve_model_path(old_model.path)
|
||||||
old_model_cache = self._get_model_cache_path(old_model_path)
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
if old_model_cache.exists():
|
if old_model_cache.exists():
|
||||||
if old_model_cache.is_dir():
|
if old_model_cache.is_dir():
|
||||||
@ -780,7 +780,7 @@ class ModelManager(object):
|
|||||||
model_type,
|
model_type,
|
||||||
**submodel,
|
**submodel,
|
||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.resolve_model_path(info["path"])
|
||||||
old_diffusers_path = self.resolve_model_path(model.location)
|
old_diffusers_path = self.resolve_model_path(model.location)
|
||||||
new_diffusers_path = (
|
new_diffusers_path = (
|
||||||
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
||||||
@ -992,7 +992,7 @@ class ModelManager(object):
|
|||||||
model_manager=self,
|
model_manager=self,
|
||||||
prediction_type_helper=ask_user_for_prediction_type,
|
prediction_type_helper=ask_user_for_prediction_type,
|
||||||
)
|
)
|
||||||
known_paths = {config.root_path / x["path"] for x in self.list_models()}
|
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
|
||||||
directories = {
|
directories = {
|
||||||
config.root_path / x
|
config.root_path / x
|
||||||
for x in [
|
for x in [
|
||||||
|
@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
|||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
checkpoint = self.checkpoint
|
checkpoint = self.checkpoint
|
||||||
|
|
||||||
|
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
|
||||||
|
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
|
||||||
|
# misclassified as SD-1
|
||||||
|
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
|
if key in checkpoint and checkpoint[key].shape[0] == 320:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
|
||||||
|
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
|
||||||
|
if key in checkpoint:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
|
||||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
||||||
|
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||||
|
|
||||||
lora_token_vector_length = (
|
lora_token_vector_length = (
|
||||||
checkpoint[key1].shape[1]
|
checkpoint[key1].shape[1]
|
||||||
if key1 in checkpoint
|
if key1 in checkpoint
|
||||||
else checkpoint[key2].shape[0]
|
else checkpoint[key2].shape[1]
|
||||||
if key2 in checkpoint
|
if key2 in checkpoint
|
||||||
else 768
|
else checkpoint[key3].shape[0]
|
||||||
|
if key3 in checkpoint
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if lora_token_vector_length == 768:
|
if lora_token_vector_length == 768:
|
||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
elif lora_token_vector_length == 1024:
|
elif lora_token_vector_length == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
else:
|
else:
|
||||||
return None
|
raise InvalidModelException(f"Unknown LoRA type")
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||||
|
@ -292,8 +292,9 @@ class DiffusersModel(ModelBase):
|
|||||||
)
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# print("====ERR LOAD====")
|
if not str(e).startswith("Error no file"):
|
||||||
# print(f"{variant}: {e}")
|
print("====ERR LOAD====")
|
||||||
|
print(f"{variant}: {e}")
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional, Dict, Union, Literal, Any
|
||||||
|
from pathlib import Path
|
||||||
|
from safetensors.torch import load_file
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@ -13,9 +15,6 @@ from .base import (
|
|||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: naming
|
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelFormat(str, Enum):
|
class LoRAModelFormat(str, Enum):
|
||||||
LyCORIS = "lycoris"
|
LyCORIS = "lycoris"
|
||||||
@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
|
|||||||
model = LoRAModelRaw.from_checkpoint(
|
model = LoRAModelRaw.from_checkpoint(
|
||||||
file_path=self.model_path,
|
file_path=self.model_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
|
base_model=self.base_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_size = model.calc_size()
|
self.model_size = model.calc_size()
|
||||||
@ -87,3 +87,582 @@ class LoRAModel(ModelBase):
|
|||||||
raise NotImplementedError("Diffusers lora not supported")
|
raise NotImplementedError("Diffusers lora not supported")
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALayerBase:
|
||||||
|
# rank: Optional[int]
|
||||||
|
# alpha: Optional[float]
|
||||||
|
# bias: Optional[torch.Tensor]
|
||||||
|
# layer_key: str
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def scale(self):
|
||||||
|
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
if "alpha" in values:
|
||||||
|
self.alpha = values["alpha"].item()
|
||||||
|
else:
|
||||||
|
self.alpha = None
|
||||||
|
|
||||||
|
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||||
|
self.bias = torch.sparse_coo_tensor(
|
||||||
|
values["bias_indices"],
|
||||||
|
values["bias_values"],
|
||||||
|
tuple(values["bias_size"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
self.rank = None # set in layer implementation
|
||||||
|
self.layer_key = layer_key
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
module: torch.nn.Module,
|
||||||
|
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||||
|
multiplier: float,
|
||||||
|
):
|
||||||
|
if type(module) == torch.nn.Conv2d:
|
||||||
|
op = torch.nn.functional.conv2d
|
||||||
|
extra_args = dict(
|
||||||
|
stride=module.stride,
|
||||||
|
padding=module.padding,
|
||||||
|
dilation=module.dilation,
|
||||||
|
groups=module.groups,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
op = torch.nn.functional.linear
|
||||||
|
extra_args = {}
|
||||||
|
|
||||||
|
weight = self.get_weight()
|
||||||
|
|
||||||
|
bias = self.bias if self.bias is not None else 0
|
||||||
|
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
return (
|
||||||
|
op(
|
||||||
|
*input_h,
|
||||||
|
(weight + bias).view(module.weight.shape),
|
||||||
|
None,
|
||||||
|
**extra_args,
|
||||||
|
)
|
||||||
|
* multiplier
|
||||||
|
* scale
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = 0
|
||||||
|
for val in [self.bias]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: find and debug lora/locon with bias
|
||||||
|
class LoRALayer(LoRALayerBase):
|
||||||
|
# up: torch.Tensor
|
||||||
|
# mid: Optional[torch.Tensor]
|
||||||
|
# down: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.up = values["lora_up.weight"]
|
||||||
|
self.down = values["lora_down.weight"]
|
||||||
|
if "lora_mid.weight" in values:
|
||||||
|
self.mid = values["lora_mid.weight"]
|
||||||
|
else:
|
||||||
|
self.mid = None
|
||||||
|
|
||||||
|
self.rank = self.down.shape[0]
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
if self.mid is not None:
|
||||||
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
|
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||||
|
else:
|
||||||
|
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.up, self.mid, self.down]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.up = self.up.to(device=device, dtype=dtype)
|
||||||
|
self.down = self.down.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.mid is not None:
|
||||||
|
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoHALayer(LoRALayerBase):
|
||||||
|
# w1_a: torch.Tensor
|
||||||
|
# w1_b: torch.Tensor
|
||||||
|
# w2_a: torch.Tensor
|
||||||
|
# w2_b: torch.Tensor
|
||||||
|
# t1: Optional[torch.Tensor] = None
|
||||||
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.w1_a = values["hada_w1_a"]
|
||||||
|
self.w1_b = values["hada_w1_b"]
|
||||||
|
self.w2_a = values["hada_w2_a"]
|
||||||
|
self.w2_b = values["hada_w2_b"]
|
||||||
|
|
||||||
|
if "hada_t1" in values:
|
||||||
|
self.t1 = values["hada_t1"]
|
||||||
|
else:
|
||||||
|
self.t1 = None
|
||||||
|
|
||||||
|
if "hada_t2" in values:
|
||||||
|
self.t2 = values["hada_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
|
self.rank = self.w1_b.shape[0]
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
if self.t1 is None:
|
||||||
|
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
|
else:
|
||||||
|
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||||
|
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||||
|
weight = rebuild1 * rebuild2
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||||
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||||
|
if self.t1 is not None:
|
||||||
|
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||||
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||||
|
if self.t2 is not None:
|
||||||
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class LoKRLayer(LoRALayerBase):
|
||||||
|
# w1: Optional[torch.Tensor] = None
|
||||||
|
# w1_a: Optional[torch.Tensor] = None
|
||||||
|
# w1_b: Optional[torch.Tensor] = None
|
||||||
|
# w2: Optional[torch.Tensor] = None
|
||||||
|
# w2_a: Optional[torch.Tensor] = None
|
||||||
|
# w2_b: Optional[torch.Tensor] = None
|
||||||
|
# t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
if "lokr_w1" in values:
|
||||||
|
self.w1 = values["lokr_w1"]
|
||||||
|
self.w1_a = None
|
||||||
|
self.w1_b = None
|
||||||
|
else:
|
||||||
|
self.w1 = None
|
||||||
|
self.w1_a = values["lokr_w1_a"]
|
||||||
|
self.w1_b = values["lokr_w1_b"]
|
||||||
|
|
||||||
|
if "lokr_w2" in values:
|
||||||
|
self.w2 = values["lokr_w2"]
|
||||||
|
self.w2_a = None
|
||||||
|
self.w2_b = None
|
||||||
|
else:
|
||||||
|
self.w2 = None
|
||||||
|
self.w2_a = values["lokr_w2_a"]
|
||||||
|
self.w2_b = values["lokr_w2_b"]
|
||||||
|
|
||||||
|
if "lokr_t2" in values:
|
||||||
|
self.t2 = values["lokr_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
|
if "lokr_w1_b" in values:
|
||||||
|
self.rank = values["lokr_w1_b"].shape[0]
|
||||||
|
elif "lokr_w2_b" in values:
|
||||||
|
self.rank = values["lokr_w2_b"].shape[0]
|
||||||
|
else:
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
w1 = self.w1
|
||||||
|
if w1 is None:
|
||||||
|
w1 = self.w1_a @ self.w1_b
|
||||||
|
|
||||||
|
w2 = self.w2
|
||||||
|
if w2 is None:
|
||||||
|
if self.t2 is None:
|
||||||
|
w2 = self.w2_a @ self.w2_b
|
||||||
|
else:
|
||||||
|
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||||
|
|
||||||
|
if len(w2.shape) == 4:
|
||||||
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
|
w2 = w2.contiguous()
|
||||||
|
weight = torch.kron(w1, w2)
|
||||||
|
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||||
|
if val is not None:
|
||||||
|
model_size += val.nelement() * val.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.w1 is not None:
|
||||||
|
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||||
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.w2 is not None:
|
||||||
|
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||||
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if self.t2 is not None:
|
||||||
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class FullLayer(LoRALayerBase):
|
||||||
|
# weight: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.weight = values["diff"]
|
||||||
|
|
||||||
|
if len(values.keys()) > 1:
|
||||||
|
_keys = list(values.keys())
|
||||||
|
_keys.remove("diff")
|
||||||
|
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||||
|
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
model_size += self.weight.nelement() * self.weight.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||||
|
class LoRAModelRaw: # (torch.nn.Module):
|
||||||
|
_name: str
|
||||||
|
layers: Dict[str, LoRALayer]
|
||||||
|
_device: torch.device
|
||||||
|
_dtype: torch.dtype
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
layers: Dict[str, LoRALayer],
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self._name = name
|
||||||
|
self._device = device or torch.cpu
|
||||||
|
self._dtype = dtype or torch.float32
|
||||||
|
self.layers = layers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self._device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self._dtype
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
# TODO: try revert if exception?
|
||||||
|
for key, layer in self.layers.items():
|
||||||
|
layer.to(device=device, dtype=dtype)
|
||||||
|
self._device = device
|
||||||
|
self._dtype = dtype
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = 0
|
||||||
|
for _, layer in self.layers.items():
|
||||||
|
model_size += layer.calc_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _convert_sdxl_compvis_keys(cls, state_dict):
|
||||||
|
new_state_dict = dict()
|
||||||
|
for full_key, value in state_dict.items():
|
||||||
|
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||||
|
continue # clip same
|
||||||
|
|
||||||
|
if not full_key.startswith("lora_unet_"):
|
||||||
|
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
|
||||||
|
src_key = full_key.replace("lora_unet_", "")
|
||||||
|
try:
|
||||||
|
dst_key = None
|
||||||
|
while "_" in src_key:
|
||||||
|
if src_key in SDXL_UNET_COMPVIS_MAP:
|
||||||
|
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
|
||||||
|
break
|
||||||
|
src_key = "_".join(src_key.split("_")[:-1])
|
||||||
|
|
||||||
|
if dst_key is None:
|
||||||
|
raise Exception(f"Unknown sdxl lora key - {full_key}")
|
||||||
|
new_key = full_key.replace(src_key, dst_key)
|
||||||
|
except:
|
||||||
|
print(SDXL_UNET_COMPVIS_MAP)
|
||||||
|
raise
|
||||||
|
new_state_dict[new_key] = value
|
||||||
|
return new_state_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_checkpoint(
|
||||||
|
cls,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
):
|
||||||
|
device = device or torch.device("cpu")
|
||||||
|
dtype = dtype or torch.float32
|
||||||
|
|
||||||
|
if isinstance(file_path, str):
|
||||||
|
file_path = Path(file_path)
|
||||||
|
|
||||||
|
model = cls(
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
name=file_path.stem, # TODO:
|
||||||
|
layers=dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path.suffix == ".safetensors":
|
||||||
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
|
else:
|
||||||
|
state_dict = torch.load(file_path, map_location="cpu")
|
||||||
|
|
||||||
|
state_dict = cls._group_state(state_dict)
|
||||||
|
|
||||||
|
if base_model == BaseModelType.StableDiffusionXL:
|
||||||
|
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
|
||||||
|
|
||||||
|
for layer_key, values in state_dict.items():
|
||||||
|
# lora and locon
|
||||||
|
if "lora_down.weight" in values:
|
||||||
|
layer = LoRALayer(layer_key, values)
|
||||||
|
|
||||||
|
# loha
|
||||||
|
elif "hada_w1_b" in values:
|
||||||
|
layer = LoHALayer(layer_key, values)
|
||||||
|
|
||||||
|
# lokr
|
||||||
|
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||||
|
layer = LoKRLayer(layer_key, values)
|
||||||
|
|
||||||
|
elif "diff" in values:
|
||||||
|
layer = FullLayer(layer_key, values)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# TODO: ia3/... format
|
||||||
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
|
raise Exception("Unknown lora format!")
|
||||||
|
|
||||||
|
# lower memory consumption by removing already parsed layer values
|
||||||
|
state_dict[layer_key].clear()
|
||||||
|
|
||||||
|
layer.to(device=device, dtype=dtype)
|
||||||
|
model.layers[layer_key] = layer
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _group_state(state_dict: dict):
|
||||||
|
state_dict_groupped = dict()
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
stem, leaf = key.split(".", 1)
|
||||||
|
if stem not in state_dict_groupped:
|
||||||
|
state_dict_groupped[stem] = dict()
|
||||||
|
state_dict_groupped[stem][leaf] = value
|
||||||
|
|
||||||
|
return state_dict_groupped
|
||||||
|
|
||||||
|
|
||||||
|
# code from
|
||||||
|
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||||
|
def make_sdxl_unet_conversion_map():
|
||||||
|
unet_conversion_map_layer = []
|
||||||
|
|
||||||
|
for i in range(3): # num_blocks is 3 in sdxl
|
||||||
|
# loop over downblocks/upblocks
|
||||||
|
for j in range(2):
|
||||||
|
# loop over resnets/attentions for downblocks
|
||||||
|
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||||
|
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no attention layers in down_blocks.3
|
||||||
|
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||||
|
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(3):
|
||||||
|
# loop over resnets/attentions for upblocks
|
||||||
|
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||||
|
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||||
|
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||||
|
|
||||||
|
# if i > 0: commentout for sdxl
|
||||||
|
# no attention layers in up_blocks.0
|
||||||
|
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||||
|
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||||
|
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||||
|
|
||||||
|
if i < 3:
|
||||||
|
# no downsample in down_blocks.3
|
||||||
|
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||||
|
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||||
|
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||||
|
|
||||||
|
# no upsample in up_blocks.3
|
||||||
|
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||||
|
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||||
|
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||||
|
|
||||||
|
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||||
|
sd_mid_atn_prefix = "middle_block.1."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||||
|
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||||
|
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||||
|
|
||||||
|
unet_conversion_map_resnet = [
|
||||||
|
# (stable-diffusion, HF Diffusers)
|
||||||
|
("in_layers.0.", "norm1."),
|
||||||
|
("in_layers.2.", "conv1."),
|
||||||
|
("out_layers.0.", "norm2."),
|
||||||
|
("out_layers.3.", "conv2."),
|
||||||
|
("emb_layers.1.", "time_emb_proj."),
|
||||||
|
("skip_connection.", "conv_shortcut."),
|
||||||
|
]
|
||||||
|
|
||||||
|
unet_conversion_map = []
|
||||||
|
for sd, hf in unet_conversion_map_layer:
|
||||||
|
if "resnets" in hf:
|
||||||
|
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||||
|
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||||
|
else:
|
||||||
|
unet_conversion_map.append((sd, hf))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||||
|
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||||
|
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||||
|
|
||||||
|
for j in range(2):
|
||||||
|
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||||
|
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||||
|
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||||
|
|
||||||
|
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||||
|
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||||
|
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||||
|
|
||||||
|
return unet_conversion_map
|
||||||
|
|
||||||
|
|
||||||
|
SDXL_UNET_COMPVIS_MAP = {
|
||||||
|
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
|
||||||
|
for sd, hf in make_sdxl_unet_conversion_map()
|
||||||
|
}
|
||||||
|
@ -4,6 +4,7 @@ from enum import Enum
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -263,6 +264,8 @@ def _convert_ckpt_and_cache(
|
|||||||
weights = app_config.models_path / model_config.path
|
weights = app_config.models_path / model_config.path
|
||||||
config_file = app_config.root_path / model_config.config
|
config_file = app_config.root_path / model_config.config
|
||||||
output_path = Path(output_path)
|
output_path = Path(output_path)
|
||||||
|
variant = model_config.variant
|
||||||
|
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
|
||||||
|
|
||||||
# return cached version if it exists
|
# return cached version if it exists
|
||||||
if output_path.exists():
|
if output_path.exists():
|
||||||
@ -289,6 +292,7 @@ def _convert_ckpt_and_cache(
|
|||||||
original_config_file=config_file,
|
original_config_file=config_file,
|
||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
scan_needed=True,
|
scan_needed=True,
|
||||||
|
pipeline_class=pipeline_class,
|
||||||
from_safetensors=weights.suffix == ".safetensors",
|
from_safetensors=weights.suffix == ".safetensors",
|
||||||
precision=torch_dtype(choose_torch_device()),
|
precision=torch_dtype(choose_torch_device()),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
@ -78,10 +78,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
self.sequential_guidance = config.sequential_guidance
|
self.sequential_guidance = config.sequential_guidance
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def custom_attention_context(
|
def custom_attention_context(
|
||||||
cls,
|
self,
|
||||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||||
step_count: int,
|
step_count: int,
|
||||||
@ -91,18 +90,19 @@ class InvokeAIDiffuserComponent:
|
|||||||
old_attn_processors = unet.attn_processors
|
old_attn_processors = unet.attn_processors
|
||||||
# Load lora conditions into the model
|
# Load lora conditions into the model
|
||||||
if extra_conditioning_info.wants_cross_attention_control:
|
if extra_conditioning_info.wants_cross_attention_control:
|
||||||
cross_attention_control_context = Context(
|
self.cross_attention_control_context = Context(
|
||||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||||
step_count=step_count,
|
step_count=step_count,
|
||||||
)
|
)
|
||||||
setup_cross_attention_control_attention_processors(
|
setup_cross_attention_control_attention_processors(
|
||||||
unet,
|
unet,
|
||||||
cross_attention_control_context,
|
self.cross_attention_control_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield None
|
yield None
|
||||||
finally:
|
finally:
|
||||||
|
self.cross_attention_control_context = None
|
||||||
if old_attn_processors is not None:
|
if old_attn_processors is not None:
|
||||||
unet.set_attn_processor(old_attn_processors)
|
unet.set_attn_processor(old_attn_processors)
|
||||||
# TODO resuscitate attention map saving
|
# TODO resuscitate attention map saving
|
||||||
|
@ -23,7 +23,7 @@
|
|||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
"typegen": "npx ts-node scripts/typegen.ts",
|
"typegen": "node scripts/typegen.js",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:madge": "madge --circular src/main.tsx",
|
"lint:madge": "madge --circular src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
|
@ -124,7 +124,8 @@
|
|||||||
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
|
||||||
"deleteImagePermanent": "Deleted images cannot be restored.",
|
"deleteImagePermanent": "Deleted images cannot be restored.",
|
||||||
"images": "Images",
|
"images": "Images",
|
||||||
"assets": "Assets"
|
"assets": "Assets",
|
||||||
|
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
|
||||||
},
|
},
|
||||||
"hotkeys": {
|
"hotkeys": {
|
||||||
"keyboardShortcuts": "Keyboard Shortcuts",
|
"keyboardShortcuts": "Keyboard Shortcuts",
|
||||||
|
@ -4,8 +4,9 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import ImageUploader from 'common/components/ImageUploader';
|
import ImageUploader from 'common/components/ImageUploader';
|
||||||
|
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
|
||||||
|
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||||
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
|
||||||
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
|
|
||||||
import SiteHeader from 'features/system/components/SiteHeader';
|
import SiteHeader from 'features/system/components/SiteHeader';
|
||||||
import { configChanged } from 'features/system/store/configSlice';
|
import { configChanged } from 'features/system/store/configSlice';
|
||||||
import { languageSelector } from 'features/system/store/systemSelectors';
|
import { languageSelector } from 'features/system/store/systemSelectors';
|
||||||
@ -16,7 +17,6 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
|
|||||||
import i18n from 'i18n';
|
import i18n from 'i18n';
|
||||||
import { size } from 'lodash-es';
|
import { size } from 'lodash-es';
|
||||||
import { ReactNode, memo, useEffect } from 'react';
|
import { ReactNode, memo, useEffect } from 'react';
|
||||||
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
|
|
||||||
import GlobalHotkeys from './GlobalHotkeys';
|
import GlobalHotkeys from './GlobalHotkeys';
|
||||||
import Toaster from './Toaster';
|
import Toaster from './Toaster';
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
|
|||||||
</Portal>
|
</Portal>
|
||||||
</Grid>
|
</Grid>
|
||||||
<DeleteImageModal />
|
<DeleteImageModal />
|
||||||
<UpdateImageBoardModal />
|
<ChangeBoardModal />
|
||||||
<Toaster />
|
<Toaster />
|
||||||
<GlobalHotkeys />
|
<GlobalHotkeys />
|
||||||
</>
|
</>
|
||||||
|
@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (props.dragData.payloadType === 'IMAGE_NAMES') {
|
if (props.dragData.payloadType === 'IMAGE_DTOS') {
|
||||||
return (
|
return (
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
|
|||||||
...STYLES,
|
...STYLES,
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Heading>{props.dragData.payload.image_names.length}</Heading>
|
<Heading>{props.dragData.payload.imageDTOs.length}</Heading>
|
||||||
<Heading size="sm">Images</Heading>
|
<Heading size="sm">Images</Heading>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
|
@ -18,27 +18,32 @@ import {
|
|||||||
DragStartEvent,
|
DragStartEvent,
|
||||||
TypesafeDraggableData,
|
TypesafeDraggableData,
|
||||||
} from './typesafeDnd';
|
} from './typesafeDnd';
|
||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
|
||||||
type ImageDndContextProps = PropsWithChildren;
|
type ImageDndContextProps = PropsWithChildren;
|
||||||
|
|
||||||
const ImageDndContext = (props: ImageDndContextProps) => {
|
const ImageDndContext = (props: ImageDndContextProps) => {
|
||||||
const [activeDragData, setActiveDragData] =
|
const [activeDragData, setActiveDragData] =
|
||||||
useState<TypesafeDraggableData | null>(null);
|
useState<TypesafeDraggableData | null>(null);
|
||||||
|
const log = logger('images');
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const handleDragStart = useCallback((event: DragStartEvent) => {
|
const handleDragStart = useCallback(
|
||||||
console.log('dragStart', event.active.data.current);
|
(event: DragStartEvent) => {
|
||||||
|
log.trace({ dragData: event.active.data.current }, 'Drag started');
|
||||||
const activeData = event.active.data.current;
|
const activeData = event.active.data.current;
|
||||||
if (!activeData) {
|
if (!activeData) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setActiveDragData(activeData);
|
setActiveDragData(activeData);
|
||||||
}, []);
|
},
|
||||||
|
[log]
|
||||||
|
);
|
||||||
|
|
||||||
const handleDragEnd = useCallback(
|
const handleDragEnd = useCallback(
|
||||||
(event: DragEndEvent) => {
|
(event: DragEndEvent) => {
|
||||||
console.log('dragEnd', event.active.data.current);
|
log.trace({ dragData: event.active.data.current }, 'Drag ended');
|
||||||
const overData = event.over?.data.current;
|
const overData = event.over?.data.current;
|
||||||
if (!activeDragData || !overData) {
|
if (!activeDragData || !overData) {
|
||||||
return;
|
return;
|
||||||
@ -46,7 +51,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
|||||||
dispatch(dndDropped({ overData, activeData: activeDragData }));
|
dispatch(dndDropped({ overData, activeData: activeDragData }));
|
||||||
setActiveDragData(null);
|
setActiveDragData(null);
|
||||||
},
|
},
|
||||||
[activeDragData, dispatch]
|
[activeDragData, dispatch, log]
|
||||||
);
|
);
|
||||||
|
|
||||||
const mouseSensor = useSensor(MouseSensor, {
|
const mouseSensor = useSensor(MouseSensor, {
|
||||||
|
@ -11,7 +11,6 @@ import {
|
|||||||
useDraggable as useOriginalDraggable,
|
useDraggable as useOriginalDraggable,
|
||||||
useDroppable as useOriginalDroppable,
|
useDroppable as useOriginalDroppable,
|
||||||
} from '@dnd-kit/core';
|
} from '@dnd-kit/core';
|
||||||
import { BoardId } from 'features/gallery/store/types';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
type BaseDropData = {
|
type BaseDropData = {
|
||||||
@ -54,9 +53,13 @@ export type AddToBatchDropData = BaseDropData & {
|
|||||||
actionType: 'ADD_TO_BATCH';
|
actionType: 'ADD_TO_BATCH';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type MoveBoardDropData = BaseDropData & {
|
export type AddToBoardDropData = BaseDropData & {
|
||||||
actionType: 'MOVE_BOARD';
|
actionType: 'ADD_TO_BOARD';
|
||||||
context: { boardId: BoardId };
|
context: { boardId: string };
|
||||||
|
};
|
||||||
|
|
||||||
|
export type RemoveFromBoardDropData = BaseDropData & {
|
||||||
|
actionType: 'REMOVE_FROM_BOARD';
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDroppableData =
|
export type TypesafeDroppableData =
|
||||||
@ -67,7 +70,8 @@ export type TypesafeDroppableData =
|
|||||||
| NodesImageDropData
|
| NodesImageDropData
|
||||||
| AddToBatchDropData
|
| AddToBatchDropData
|
||||||
| NodesMultiImageDropData
|
| NodesMultiImageDropData
|
||||||
| MoveBoardDropData;
|
| AddToBoardDropData
|
||||||
|
| RemoveFromBoardDropData;
|
||||||
|
|
||||||
type BaseDragData = {
|
type BaseDragData = {
|
||||||
id: string;
|
id: string;
|
||||||
@ -78,14 +82,12 @@ export type ImageDraggableData = BaseDragData & {
|
|||||||
payload: { imageDTO: ImageDTO };
|
payload: { imageDTO: ImageDTO };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ImageNamesDraggableData = BaseDragData & {
|
export type ImageDTOsDraggableData = BaseDragData & {
|
||||||
payloadType: 'IMAGE_NAMES';
|
payloadType: 'IMAGE_DTOS';
|
||||||
payload: { image_names: string[] };
|
payload: { imageDTOs: ImageDTO[] };
|
||||||
};
|
};
|
||||||
|
|
||||||
export type TypesafeDraggableData =
|
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
|
||||||
| ImageDraggableData
|
|
||||||
| ImageNamesDraggableData;
|
|
||||||
|
|
||||||
interface UseDroppableTypesafeArguments
|
interface UseDroppableTypesafeArguments
|
||||||
extends Omit<UseDroppableArguments, 'data'> {
|
extends Omit<UseDroppableArguments, 'data'> {
|
||||||
@ -156,14 +158,39 @@ export const isValidDrop = (
|
|||||||
case 'SET_NODES_IMAGE':
|
case 'SET_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO';
|
return payloadType === 'IMAGE_DTO';
|
||||||
case 'SET_MULTI_NODES_IMAGE':
|
case 'SET_MULTI_NODES_IMAGE':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
case 'ADD_TO_BATCH':
|
case 'ADD_TO_BATCH':
|
||||||
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
case 'MOVE_BOARD': {
|
case 'ADD_TO_BOARD': {
|
||||||
// If the board is the same, don't allow the drop
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
// Check the payload types
|
// Check the payload types
|
||||||
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
|
if (!isPayloadValid) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the image's board is the board we are dragging onto
|
||||||
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
|
const { imageDTO } = active.data.current.payload;
|
||||||
|
const currentBoard = imageDTO.board_id ?? 'none';
|
||||||
|
const destinationBoard = overData.context.boardId;
|
||||||
|
|
||||||
|
return currentBoard !== destinationBoard;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
|
// TODO (multi-select)
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case 'REMOVE_FROM_BOARD': {
|
||||||
|
// If the board is the same, don't allow the drop
|
||||||
|
|
||||||
|
// Check the payload types
|
||||||
|
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
|
||||||
if (!isPayloadValid) {
|
if (!isPayloadValid) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -172,20 +199,16 @@ export const isValidDrop = (
|
|||||||
if (payloadType === 'IMAGE_DTO') {
|
if (payloadType === 'IMAGE_DTO') {
|
||||||
const { imageDTO } = active.data.current.payload;
|
const { imageDTO } = active.data.current.payload;
|
||||||
const currentBoard = imageDTO.board_id;
|
const currentBoard = imageDTO.board_id;
|
||||||
const destinationBoard = overData.context.boardId;
|
|
||||||
|
|
||||||
const isSameBoard = currentBoard === destinationBoard;
|
return currentBoard !== 'none';
|
||||||
const isDestinationValid = !currentBoard ? destinationBoard : true;
|
|
||||||
|
|
||||||
return !isSameBoard && isDestinationValid;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (payloadType === 'IMAGE_NAMES') {
|
if (payloadType === 'IMAGE_DTOS') {
|
||||||
// TODO (multi-select)
|
// TODO (multi-select)
|
||||||
return false;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return false;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
import { Middleware } from '@reduxjs/toolkit';
|
||||||
import { store } from 'app/store/store';
|
import { store } from 'app/store/store';
|
||||||
|
import { PartialAppConfig } from 'app/types/invokeai';
|
||||||
import React, {
|
import React, {
|
||||||
lazy,
|
lazy,
|
||||||
memo,
|
memo,
|
||||||
@ -7,16 +9,11 @@ import React, {
|
|||||||
useEffect,
|
useEffect,
|
||||||
} from 'react';
|
} from 'react';
|
||||||
import { Provider } from 'react-redux';
|
import { Provider } from 'react-redux';
|
||||||
|
|
||||||
import { PartialAppConfig } from 'app/types/invokeai';
|
|
||||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||||
import Loading from '../../common/components/Loading/Loading';
|
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
|
||||||
|
|
||||||
import { Middleware } from '@reduxjs/toolkit';
|
|
||||||
import { $authToken, $baseUrl } from 'services/api/client';
|
|
||||||
import { socketMiddleware } from 'services/events/middleware';
|
import { socketMiddleware } from 'services/events/middleware';
|
||||||
|
import Loading from '../../common/components/Loading/Loading';
|
||||||
import '../../i18n';
|
import '../../i18n';
|
||||||
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
|
|
||||||
import ImageDndContext from './ImageDnd/ImageDndContext';
|
import ImageDndContext from './ImageDnd/ImageDndContext';
|
||||||
|
|
||||||
const App = lazy(() => import('./App'));
|
const App = lazy(() => import('./App'));
|
||||||
@ -37,6 +34,7 @@ const InvokeAIUI = ({
|
|||||||
config,
|
config,
|
||||||
headerComponent,
|
headerComponent,
|
||||||
middleware,
|
middleware,
|
||||||
|
projectId,
|
||||||
}: Props) => {
|
}: Props) => {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// configure API client token
|
// configure API client token
|
||||||
@ -49,6 +47,11 @@ const InvokeAIUI = ({
|
|||||||
$baseUrl.set(apiUrl);
|
$baseUrl.set(apiUrl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configure API client project header
|
||||||
|
if (projectId) {
|
||||||
|
$projectId.set(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
// reset dynamically added middlewares
|
// reset dynamically added middlewares
|
||||||
resetMiddlewares();
|
resetMiddlewares();
|
||||||
|
|
||||||
@ -68,8 +71,9 @@ const InvokeAIUI = ({
|
|||||||
// Reset the API client token and base url on unmount
|
// Reset the API client token and base url on unmount
|
||||||
$baseUrl.set(undefined);
|
$baseUrl.set(undefined);
|
||||||
$authToken.set(undefined);
|
$authToken.set(undefined);
|
||||||
|
$projectId.set(undefined);
|
||||||
};
|
};
|
||||||
}, [apiUrl, token, middleware]);
|
}, [apiUrl, token, middleware, projectId]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<React.StrictMode>
|
<React.StrictMode>
|
||||||
@ -77,9 +81,7 @@ const InvokeAIUI = ({
|
|||||||
<React.Suspense fallback={<Loading />}>
|
<React.Suspense fallback={<Loading />}>
|
||||||
<ThemeLocaleProvider>
|
<ThemeLocaleProvider>
|
||||||
<ImageDndContext>
|
<ImageDndContext>
|
||||||
<AddImageToBoardContextProvider>
|
|
||||||
<App config={config} headerComponent={headerComponent} />
|
<App config={config} headerComponent={headerComponent} />
|
||||||
</AddImageToBoardContextProvider>
|
|
||||||
</ImageDndContext>
|
</ImageDndContext>
|
||||||
</ThemeLocaleProvider>
|
</ThemeLocaleProvider>
|
||||||
</React.Suspense>
|
</React.Suspense>
|
||||||
|
@ -1,91 +0,0 @@
|
|||||||
import { useDisclosure } from '@chakra-ui/react';
|
|
||||||
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
|
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
|
||||||
import { useAppDispatch } from '../store/storeHooks';
|
|
||||||
|
|
||||||
export type ImageUsage = {
|
|
||||||
isInitialImage: boolean;
|
|
||||||
isCanvasImage: boolean;
|
|
||||||
isNodesImage: boolean;
|
|
||||||
isControlNetImage: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
type AddImageToBoardContextValue = {
|
|
||||||
/**
|
|
||||||
* Whether the move image dialog is open.
|
|
||||||
*/
|
|
||||||
isOpen: boolean;
|
|
||||||
/**
|
|
||||||
* Closes the move image dialog.
|
|
||||||
*/
|
|
||||||
onClose: () => void;
|
|
||||||
/**
|
|
||||||
* The image pending movement
|
|
||||||
*/
|
|
||||||
image?: ImageDTO;
|
|
||||||
onClickAddToBoard: (image: ImageDTO) => void;
|
|
||||||
handleAddToBoard: (boardId: string) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const AddImageToBoardContext =
|
|
||||||
createContext<AddImageToBoardContextValue>({
|
|
||||||
isOpen: false,
|
|
||||||
onClose: () => undefined,
|
|
||||||
onClickAddToBoard: () => undefined,
|
|
||||||
handleAddToBoard: () => undefined,
|
|
||||||
});
|
|
||||||
|
|
||||||
type Props = PropsWithChildren;
|
|
||||||
|
|
||||||
export const AddImageToBoardContextProvider = (props: Props) => {
|
|
||||||
const [imageToMove, setImageToMove] = useState<ImageDTO>();
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
// Clean up after deleting or dismissing the modal
|
|
||||||
const closeAndClearImageToDelete = useCallback(() => {
|
|
||||||
setImageToMove(undefined);
|
|
||||||
onClose();
|
|
||||||
}, [onClose]);
|
|
||||||
|
|
||||||
const onClickAddToBoard = useCallback(
|
|
||||||
(image?: ImageDTO) => {
|
|
||||||
if (!image) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setImageToMove(image);
|
|
||||||
onOpen();
|
|
||||||
},
|
|
||||||
[setImageToMove, onOpen]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleAddToBoard = useCallback(
|
|
||||||
(boardId: string) => {
|
|
||||||
if (imageToMove) {
|
|
||||||
dispatch(
|
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
|
||||||
imageDTO: imageToMove,
|
|
||||||
board_id: boardId,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
closeAndClearImageToDelete();
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[dispatch, closeAndClearImageToDelete, imageToMove]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<AddImageToBoardContext.Provider
|
|
||||||
value={{
|
|
||||||
isOpen,
|
|
||||||
image: imageToMove,
|
|
||||||
onClose: closeAndClearImageToDelete,
|
|
||||||
onClickAddToBoard,
|
|
||||||
handleAddToBoard,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{props.children}
|
|
||||||
</AddImageToBoardContext.Provider>
|
|
||||||
);
|
|
||||||
};
|
|
@ -1,8 +0,0 @@
|
|||||||
import { createContext } from 'react';
|
|
||||||
|
|
||||||
type VoidFunc = () => void;
|
|
||||||
|
|
||||||
type ImageUploaderTriggerContextType = VoidFunc | null;
|
|
||||||
|
|
||||||
export const ImageUploaderTriggerContext =
|
|
||||||
createContext<ImageUploaderTriggerContextType>(null);
|
|
@ -23,6 +23,6 @@ const serializationDenylist: {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export const serialize: SerializeFunction = (data, key) => {
|
export const serialize: SerializeFunction = (data, key) => {
|
||||||
const result = omit(data, serializationDenylist[key]);
|
const result = omit(data, serializationDenylist[key] ?? []);
|
||||||
return JSON.stringify(result);
|
return JSON.stringify(result);
|
||||||
};
|
};
|
||||||
|
@ -27,7 +27,8 @@ import {
|
|||||||
addImageDeletedFulfilledListener,
|
addImageDeletedFulfilledListener,
|
||||||
addImageDeletedPendingListener,
|
addImageDeletedPendingListener,
|
||||||
addImageDeletedRejectedListener,
|
addImageDeletedRejectedListener,
|
||||||
addRequestedImageDeletionListener,
|
addRequestedSingleImageDeletionListener,
|
||||||
|
addRequestedMultipleImageDeletionListener,
|
||||||
} from './listeners/imageDeleted';
|
} from './listeners/imageDeleted';
|
||||||
import { addImageDroppedListener } from './listeners/imageDropped';
|
import { addImageDroppedListener } from './listeners/imageDropped';
|
||||||
import {
|
import {
|
||||||
@ -111,7 +112,8 @@ addImageUploadedRejectedListener();
|
|||||||
addInitialImageSelectedListener();
|
addInitialImageSelectedListener();
|
||||||
|
|
||||||
// Image deleted
|
// Image deleted
|
||||||
addRequestedImageDeletionListener();
|
addRequestedSingleImageDeletionListener();
|
||||||
|
addRequestedMultipleImageDeletionListener();
|
||||||
addImageDeletedPendingListener();
|
addImageDeletedPendingListener();
|
||||||
addImageDeletedFulfilledListener();
|
addImageDeletedFulfilledListener();
|
||||||
addImageDeletedRejectedListener();
|
addImageDeletedRejectedListener();
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import {
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
ImageCache,
|
|
||||||
getListImagesUrl,
|
|
||||||
imagesApi,
|
|
||||||
} from 'services/api/endpoints/images';
|
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
|
||||||
|
import { ImageCache } from 'services/api/types';
|
||||||
|
|
||||||
export const appStarted = createAction('app/appStarted');
|
export const appStarted = createAction('app/appStarted');
|
||||||
|
|
||||||
@ -34,7 +32,8 @@ export const addFirstListImagesListener = () => {
|
|||||||
|
|
||||||
if (data.ids.length > 0) {
|
if (data.ids.length > 0) {
|
||||||
// Select the first image
|
// Select the first image
|
||||||
dispatch(imageSelected(data.ids[0] as string));
|
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
|
||||||
|
dispatch(imageSelected(firstImage ?? null));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -18,7 +18,9 @@ export const addAppConfigReceivedListener = () => {
|
|||||||
const infillMethod = getState().generation.infillMethod;
|
const infillMethod = getState().generation.infillMethod;
|
||||||
|
|
||||||
if (!infill_methods.includes(infillMethod)) {
|
if (!infill_methods.includes(infillMethod)) {
|
||||||
dispatch(setInfillMethod(infill_methods[0]));
|
// if there is no infill method, set it to the first one
|
||||||
|
// if there is no first one... god help us
|
||||||
|
dispatch(setInfillMethod(infill_methods[0] as string));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!nsfw_methods.includes('nsfw_checker')) {
|
if (!nsfw_methods.includes('nsfw_checker')) {
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
|
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { boardsApi } from '../../../../../services/api/endpoints/boards';
|
|
||||||
|
|
||||||
export const addDeleteBoardAndImagesFulfilledListener = () => {
|
export const addDeleteBoardAndImagesFulfilledListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
|
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const { deleted_images } = action.payload;
|
const { deleted_images } = action.payload;
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ import {
|
|||||||
} from 'features/gallery/store/types';
|
} from 'features/gallery/store/types';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
import { imagesSelectors } from 'services/api/util';
|
||||||
|
|
||||||
export const addBoardIdSelectedListener = () => {
|
export const addBoardIdSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -52,8 +53,9 @@ export const addBoardIdSelectedListener = () => {
|
|||||||
queryArgs
|
queryArgs
|
||||||
)(getState());
|
)(getState());
|
||||||
|
|
||||||
if (boardImagesData?.ids.length) {
|
if (boardImagesData) {
|
||||||
dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
|
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||||
|
dispatch(imageSelected(firstImage ?? null));
|
||||||
} else {
|
} else {
|
||||||
// board has no images - deselect
|
// board has no images - deselect
|
||||||
dispatch(imageSelected(null));
|
dispatch(imageSelected(null));
|
||||||
|
@ -26,6 +26,8 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const { autoAddBoardId } = state.gallery;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.uploadImage.initiate({
|
imagesApi.endpoints.uploadImage.initiate({
|
||||||
file: new File([blob], 'savedCanvas.png', {
|
file: new File([blob], 'savedCanvas.png', {
|
||||||
@ -33,7 +35,7 @@ export const addCanvasSavedToGalleryListener = () => {
|
|||||||
}),
|
}),
|
||||||
image_category: 'general',
|
image_category: 'general',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
board_id: state.gallery.autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
crop_visible: true,
|
crop_visible: true,
|
||||||
postUploadAction: {
|
postUploadAction: {
|
||||||
type: 'TOAST',
|
type: 'TOAST',
|
||||||
|
@ -31,15 +31,20 @@ const predicate: AnyListenerPredicate<RootState> = (
|
|||||||
// do not process if the user just disabled auto-config
|
// do not process if the user just disabled auto-config
|
||||||
if (
|
if (
|
||||||
prevState.controlNet.controlNets[action.payload.controlNetId]
|
prevState.controlNet.controlNets[action.payload.controlNetId]
|
||||||
.shouldAutoConfig === true
|
?.shouldAutoConfig === true
|
||||||
) {
|
) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { controlImage, processorType, shouldAutoConfig } =
|
const cn = state.controlNet.controlNets[action.payload.controlNetId];
|
||||||
state.controlNet.controlNets[action.payload.controlNetId];
|
|
||||||
|
|
||||||
|
if (!cn) {
|
||||||
|
// something is wrong, the controlNet should exist
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { controlImage, processorType, shouldAutoConfig } = cn;
|
||||||
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
|
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
|
||||||
// do not process if the action is a model change but the processor settings are dirty
|
// do not process if the action is a model change but the processor settings are dirty
|
||||||
return false;
|
return false;
|
||||||
|
@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
|
|||||||
const { controlNetId } = action.payload;
|
const { controlNetId } = action.payload;
|
||||||
const controlNet = getState().controlNet.controlNets[controlNetId];
|
const controlNet = getState().controlNet.controlNets[controlNetId];
|
||||||
|
|
||||||
if (!controlNet.controlImage) {
|
if (!controlNet?.controlImage) {
|
||||||
log.error('Unable to process ControlNet image');
|
log.error('Unable to process ControlNet image');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1,57 +1,72 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
import { resetCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||||
|
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
|
||||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
|
|
||||||
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
|
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||||
import { clamp } from 'lodash-es';
|
import { clamp } from 'lodash-es';
|
||||||
import { api } from 'services/api';
|
import { api } from 'services/api';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
|
import { imagesAdapter } from 'services/api/util';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
/**
|
export const addRequestedSingleImageDeletionListener = () => {
|
||||||
* Called when the user requests an image deletion
|
|
||||||
*/
|
|
||||||
export const addRequestedImageDeletionListener = () => {
|
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageDeletionConfirmed,
|
actionCreator: imageDeletionConfirmed,
|
||||||
effect: async (action, { dispatch, getState, condition }) => {
|
effect: async (action, { dispatch, getState, condition }) => {
|
||||||
const { imageDTO, imageUsage } = action.payload;
|
const { imageDTOs, imagesUsage } = action.payload;
|
||||||
|
|
||||||
|
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
|
||||||
|
// handle multiples in separate listener
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const imageDTO = imageDTOs[0];
|
||||||
|
const imageUsage = imagesUsage[0];
|
||||||
|
|
||||||
|
if (!imageDTO || !imageUsage) {
|
||||||
|
// satisfy noUncheckedIndexedAccess
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(isModalOpenChanged(false));
|
dispatch(isModalOpenChanged(false));
|
||||||
|
|
||||||
const { image_name } = imageDTO;
|
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const lastSelectedImage =
|
const lastSelectedImage =
|
||||||
state.gallery.selection[state.gallery.selection.length - 1];
|
state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||||
|
|
||||||
|
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||||
|
const { image_name } = imageDTO;
|
||||||
|
|
||||||
if (lastSelectedImage === image_name) {
|
|
||||||
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
||||||
const { data } =
|
const { data } =
|
||||||
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||||
|
|
||||||
const ids = data?.ids ?? [];
|
const cachedImageDTOs = data
|
||||||
|
? imagesAdapter.getSelectors().selectAll(data)
|
||||||
|
: [];
|
||||||
|
|
||||||
const deletedImageIndex = ids.findIndex(
|
const deletedImageIndex = cachedImageDTOs.findIndex(
|
||||||
(result) => result.toString() === image_name
|
(i) => i.image_name === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const filteredIds = ids.filter((id) => id.toString() !== image_name);
|
const filteredImageDTOs = cachedImageDTOs.filter(
|
||||||
|
(i) => i.image_name !== image_name
|
||||||
|
);
|
||||||
|
|
||||||
const newSelectedImageIndex = clamp(
|
const newSelectedImageIndex = clamp(
|
||||||
deletedImageIndex,
|
deletedImageIndex,
|
||||||
0,
|
0,
|
||||||
filteredIds.length - 1
|
filteredImageDTOs.length - 1
|
||||||
);
|
);
|
||||||
|
|
||||||
const newSelectedImageId = filteredIds[newSelectedImageIndex];
|
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
|
||||||
|
|
||||||
if (newSelectedImageId) {
|
if (newSelectedImageDTO) {
|
||||||
dispatch(imageSelected(newSelectedImageId as string));
|
dispatch(imageSelected(newSelectedImageDTO));
|
||||||
} else {
|
} else {
|
||||||
dispatch(imageSelected(null));
|
dispatch(imageSelected(null));
|
||||||
}
|
}
|
||||||
@ -97,6 +112,66 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Called when the user requests an image deletion
|
||||||
|
*/
|
||||||
|
export const addRequestedMultipleImageDeletionListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: imageDeletionConfirmed,
|
||||||
|
effect: async (action, { dispatch, getState }) => {
|
||||||
|
const { imageDTOs, imagesUsage } = action.payload;
|
||||||
|
|
||||||
|
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
|
||||||
|
// handle singles in separate listener
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Delete from server
|
||||||
|
await dispatch(
|
||||||
|
imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
|
||||||
|
).unwrap();
|
||||||
|
const state = getState();
|
||||||
|
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
|
||||||
|
const { data } =
|
||||||
|
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||||
|
|
||||||
|
const newSelectedImageDTO = data
|
||||||
|
? imagesAdapter.getSelectors().selectAll(data)[0]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
if (newSelectedImageDTO) {
|
||||||
|
dispatch(imageSelected(newSelectedImageDTO));
|
||||||
|
} else {
|
||||||
|
dispatch(imageSelected(null));
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
|
|
||||||
|
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isCanvasImage)) {
|
||||||
|
dispatch(resetCanvas());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isControlNetImage)) {
|
||||||
|
dispatch(controlNetReset());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isInitialImage)) {
|
||||||
|
dispatch(clearInitialImage());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (imagesUsage.some((i) => i.isNodesImage)) {
|
||||||
|
dispatch(nodeEditorReset());
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Called when the actual delete request is sent to the server
|
* Called when the actual delete request is sent to the server
|
||||||
*/
|
*/
|
||||||
|
@ -6,10 +6,7 @@ import {
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import {
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
imageSelected,
|
|
||||||
imagesAddedToBatch,
|
|
||||||
} from 'features/gallery/store/gallerySlice';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
@ -27,19 +24,32 @@ export const addImageDroppedListener = () => {
|
|||||||
const log = logger('images');
|
const log = logger('images');
|
||||||
const { activeData, overData } = action.payload;
|
const { activeData, overData } = action.payload;
|
||||||
|
|
||||||
log.debug({ activeData, overData }, 'Image or selection dropped');
|
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||||
|
log.debug({ activeData, overData }, 'Image dropped');
|
||||||
|
} else if (activeData.payloadType === 'IMAGE_DTOS') {
|
||||||
|
log.debug(
|
||||||
|
{ activeData, overData },
|
||||||
|
`Images (${activeData.payload.imageDTOs.length}) dropped`
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
log.debug({ activeData, overData }, `Unknown payload dropped`);
|
||||||
|
}
|
||||||
|
|
||||||
// set current image
|
/**
|
||||||
|
* Image dropped on current image
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CURRENT_IMAGE' &&
|
overData.actionType === 'SET_CURRENT_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
|
dispatch(imageSelected(activeData.payload.imageDTO));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set initial image
|
/**
|
||||||
|
* Image dropped on initial image
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_INITIAL_IMAGE' &&
|
overData.actionType === 'SET_INITIAL_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@ -49,27 +59,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// add image to batch
|
/**
|
||||||
if (
|
* Image dropped on ControlNet
|
||||||
overData.actionType === 'ADD_TO_BATCH' &&
|
*/
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
|
||||||
activeData.payload.imageDTO
|
|
||||||
) {
|
|
||||||
dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add multiple images to batch
|
|
||||||
if (
|
|
||||||
overData.actionType === 'ADD_TO_BATCH' &&
|
|
||||||
activeData.payloadType === 'IMAGE_NAMES'
|
|
||||||
) {
|
|
||||||
dispatch(imagesAddedToBatch(activeData.payload.image_names));
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// set control image
|
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
|
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@ -85,7 +77,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set canvas image
|
/**
|
||||||
|
* Image dropped on Canvas
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
|
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@ -95,7 +89,9 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set nodes image
|
/**
|
||||||
|
* Image dropped on node image field
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'SET_NODES_IMAGE' &&
|
overData.actionType === 'SET_NODES_IMAGE' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
@ -112,61 +108,36 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set multiple nodes images (single image handler)
|
/**
|
||||||
if (
|
* TODO
|
||||||
overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
* Image selection dropped on node image collection field
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
*/
|
||||||
activeData.payload.imageDTO
|
|
||||||
) {
|
|
||||||
const { fieldName, nodeId } = overData.context;
|
|
||||||
dispatch(
|
|
||||||
fieldValueChanged({
|
|
||||||
nodeId,
|
|
||||||
fieldName,
|
|
||||||
value: [activeData.payload.imageDTO],
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// // set multiple nodes images (multiple images handler)
|
|
||||||
// if (
|
// if (
|
||||||
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||||
// activeData.payloadType === 'IMAGE_NAMES'
|
// activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
|
// activeData.payload.imageDTO
|
||||||
// ) {
|
// ) {
|
||||||
// const { fieldName, nodeId } = overData.context;
|
// const { fieldName, nodeId } = overData.context;
|
||||||
// dispatch(
|
// dispatch(
|
||||||
// imageCollectionFieldValueChanged({
|
// fieldValueChanged({
|
||||||
// nodeId,
|
// nodeId,
|
||||||
// fieldName,
|
// fieldName,
|
||||||
// value: activeData.payload.image_names.map((image_name) => ({
|
// value: [activeData.payload.imageDTO],
|
||||||
// image_name,
|
|
||||||
// })),
|
|
||||||
// })
|
// })
|
||||||
// );
|
// );
|
||||||
// return;
|
// return;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// add image to board
|
/**
|
||||||
|
* Image dropped on user board
|
||||||
|
*/
|
||||||
if (
|
if (
|
||||||
overData.actionType === 'MOVE_BOARD' &&
|
overData.actionType === 'ADD_TO_BOARD' &&
|
||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
const { imageDTO } = activeData.payload;
|
const { imageDTO } = activeData.payload;
|
||||||
const { boardId } = overData.context;
|
const { boardId } = overData.context;
|
||||||
|
|
||||||
// image was droppe on the "NoBoardBoard"
|
|
||||||
if (!boardId) {
|
|
||||||
dispatch(
|
|
||||||
imagesApi.endpoints.removeImageFromBoard.initiate({
|
|
||||||
imageDTO,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// image was dropped on a user board
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
imagesApi.endpoints.addImageToBoard.initiate({
|
||||||
imageDTO,
|
imageDTO,
|
||||||
@ -176,67 +147,58 @@ export const addImageDroppedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// // add gallery selection to board
|
/**
|
||||||
// if (
|
* Image dropped on 'none' board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId
|
overData.actionType === 'REMOVE_FROM_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
// console.log('adding gallery selection to board');
|
activeData.payload.imageDTO
|
||||||
// const board_id = overData.context.boardId;
|
) {
|
||||||
// dispatch(
|
const { imageDTO } = activeData.payload;
|
||||||
// boardImagesApi.endpoints.addManyBoardImages.initiate({
|
dispatch(
|
||||||
// board_id,
|
imagesApi.endpoints.removeImageFromBoard.initiate({
|
||||||
// image_names: activeData.payload.image_names,
|
imageDTO,
|
||||||
// })
|
})
|
||||||
// );
|
);
|
||||||
// return;
|
return;
|
||||||
// }
|
}
|
||||||
|
|
||||||
// // remove gallery selection from board
|
/**
|
||||||
// if (
|
* Multiple images dropped on user board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId === null
|
overData.actionType === 'ADD_TO_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTOS' &&
|
||||||
// console.log('removing gallery selection to board');
|
activeData.payload.imageDTOs
|
||||||
// dispatch(
|
) {
|
||||||
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
|
const { imageDTOs } = activeData.payload;
|
||||||
// image_names: activeData.payload.image_names,
|
const { boardId } = overData.context;
|
||||||
// })
|
dispatch(
|
||||||
// );
|
imagesApi.endpoints.addImagesToBoard.initiate({
|
||||||
// return;
|
imageDTOs,
|
||||||
// }
|
board_id: boardId,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// // add batch selection to board
|
/**
|
||||||
// if (
|
* Multiple images dropped on 'none' board
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
*/
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
if (
|
||||||
// overData.context.boardId
|
overData.actionType === 'REMOVE_FROM_BOARD' &&
|
||||||
// ) {
|
activeData.payloadType === 'IMAGE_DTOS' &&
|
||||||
// const board_id = overData.context.boardId;
|
activeData.payload.imageDTOs
|
||||||
// dispatch(
|
) {
|
||||||
// boardImagesApi.endpoints.addManyBoardImages.initiate({
|
const { imageDTOs } = activeData.payload;
|
||||||
// board_id,
|
dispatch(
|
||||||
// image_names: activeData.payload.image_names,
|
imagesApi.endpoints.removeImagesFromBoard.initiate({
|
||||||
// })
|
imageDTOs,
|
||||||
// );
|
})
|
||||||
// return;
|
);
|
||||||
// }
|
return;
|
||||||
|
}
|
||||||
// // remove batch selection from board
|
|
||||||
// if (
|
|
||||||
// overData.actionType === 'MOVE_BOARD' &&
|
|
||||||
// activeData.payloadType === 'IMAGE_NAMES' &&
|
|
||||||
// overData.context.boardId === null
|
|
||||||
// ) {
|
|
||||||
// dispatch(
|
|
||||||
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
|
|
||||||
// image_names: activeData.payload.image_names,
|
|
||||||
// })
|
|
||||||
// );
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,37 +1,32 @@
|
|||||||
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
|
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
|
||||||
import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
|
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||||
import {
|
import {
|
||||||
imageToDeleteSelected,
|
imagesToDeleteSelected,
|
||||||
isModalOpenChanged,
|
isModalOpenChanged,
|
||||||
} from 'features/imageDeletion/store/imageDeletionSlice';
|
} from 'features/deleteImageModal/store/slice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addImageToDeleteSelectedListener = () => {
|
export const addImageToDeleteSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
actionCreator: imageToDeleteSelected,
|
actionCreator: imagesToDeleteSelected,
|
||||||
effect: async (action, { dispatch, getState }) => {
|
effect: async (action, { dispatch, getState }) => {
|
||||||
const imageDTO = action.payload;
|
const imageDTOs = action.payload;
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const { shouldConfirmOnDelete } = state.system;
|
const { shouldConfirmOnDelete } = state.system;
|
||||||
const imageUsage = selectImageUsage(getState());
|
const imagesUsage = selectImageUsage(getState());
|
||||||
|
|
||||||
if (!imageUsage) {
|
|
||||||
// should never happen
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const isImageInUse =
|
const isImageInUse =
|
||||||
imageUsage.isCanvasImage ||
|
imagesUsage.some((i) => i.isCanvasImage) ||
|
||||||
imageUsage.isInitialImage ||
|
imagesUsage.some((i) => i.isInitialImage) ||
|
||||||
imageUsage.isControlNetImage ||
|
imagesUsage.some((i) => i.isControlNetImage) ||
|
||||||
imageUsage.isNodesImage;
|
imagesUsage.some((i) => i.isNodesImage);
|
||||||
|
|
||||||
if (shouldConfirmOnDelete || isImageInUse) {
|
if (shouldConfirmOnDelete || isImageInUse) {
|
||||||
dispatch(isModalOpenChanged(true));
|
dispatch(isModalOpenChanged(true));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
|
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -2,14 +2,13 @@ import { UseToastOptions } from '@chakra-ui/react';
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { omit } from 'lodash-es';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { imagesApi } from '../../../../../services/api/endpoints/images';
|
import { imagesApi } from '../../../../../services/api/endpoints/images';
|
||||||
import { omit } from 'lodash-es';
|
|
||||||
|
|
||||||
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
|
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
|
||||||
title: 'Image Uploaded',
|
title: 'Image Uploaded',
|
||||||
@ -41,7 +40,7 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
// default action - just upload and alert user
|
// default action - just upload and alert user
|
||||||
if (postUploadAction?.type === 'TOAST') {
|
if (postUploadAction?.type === 'TOAST') {
|
||||||
const { toastOptions } = postUploadAction;
|
const { toastOptions } = postUploadAction;
|
||||||
if (!autoAddBoardId) {
|
if (!autoAddBoardId || autoAddBoardId === 'none') {
|
||||||
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
|
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
|
||||||
} else {
|
} else {
|
||||||
// Add this image to the board
|
// Add this image to the board
|
||||||
@ -121,17 +120,6 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (postUploadAction?.type === 'ADD_TO_BATCH') {
|
|
||||||
dispatch(imagesAddedToBatch([imageDTO.image_name]));
|
|
||||||
dispatch(
|
|
||||||
addToast({
|
|
||||||
...DEFAULT_UPLOADED_TOAST,
|
|
||||||
description: 'Added to batch',
|
|
||||||
})
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -15,7 +15,7 @@ import {
|
|||||||
setShouldUseSDXLRefiner,
|
setShouldUseSDXLRefiner,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
@ -144,8 +144,9 @@ export const addModelsLoadedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const firstModelId = action.payload.ids[0];
|
const firstModel = vaeModelsAdapter
|
||||||
const firstModel = action.payload.entities[firstModelId];
|
.getSelectors()
|
||||||
|
.selectAll(action.payload)[0];
|
||||||
|
|
||||||
if (!firstModel) {
|
if (!firstModel) {
|
||||||
// No custom VAEs loaded at all; use the default
|
// No custom VAEs loaded at all; use the default
|
||||||
|
@ -8,9 +8,10 @@ import {
|
|||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
import { progressImageSet } from 'features/system/store/systemSlice';
|
import { progressImageSet } from 'features/system/store/systemSlice';
|
||||||
import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { isImageOutput } from 'services/api/guards';
|
import { isImageOutput } from 'services/api/guards';
|
||||||
import { sessionCanceled } from 'services/api/thunks/session';
|
import { sessionCanceled } from 'services/api/thunks/session';
|
||||||
|
import { imagesAdapter } from 'services/api/util';
|
||||||
import {
|
import {
|
||||||
appSocketInvocationComplete,
|
appSocketInvocationComplete,
|
||||||
socketInvocationComplete,
|
socketInvocationComplete,
|
||||||
@ -67,7 +68,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
const { autoAddBoardId } = gallery;
|
const { autoAddBoardId } = gallery;
|
||||||
if (autoAddBoardId) {
|
if (autoAddBoardId && autoAddBoardId !== 'none') {
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.endpoints.addImageToBoard.initiate({
|
imagesApi.endpoints.addImageToBoard.initiate({
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId,
|
||||||
@ -83,10 +84,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
categories: IMAGE_CATEGORIES,
|
categories: IMAGE_CATEGORIES,
|
||||||
},
|
},
|
||||||
(draft) => {
|
(draft) => {
|
||||||
const oldTotal = draft.total;
|
imagesAdapter.addOne(draft, imageDTO);
|
||||||
const newState = imagesAdapter.addOne(draft, imageDTO);
|
|
||||||
const delta = newState.total - oldTotal;
|
|
||||||
draft.total = draft.total + delta;
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@ -94,8 +92,8 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
imagesApi.util.invalidateTags([
|
imagesApi.util.invalidateTags([
|
||||||
{ type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
|
{ type: 'BoardImagesTotal', id: autoAddBoardId },
|
||||||
{ type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
|
{ type: 'BoardAssetsTotal', id: autoAddBoardId },
|
||||||
])
|
])
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -110,7 +108,7 @@ export const addInvocationCompleteEventListener = () => {
|
|||||||
} else if (!autoAddBoardId) {
|
} else if (!autoAddBoardId) {
|
||||||
dispatch(galleryViewChanged('images'));
|
dispatch(galleryViewChanged('images'));
|
||||||
}
|
}
|
||||||
dispatch(imageSelected(imageDTO.image_name));
|
dispatch(imageSelected(imageDTO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,9 +8,9 @@ import {
|
|||||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||||
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
|
||||||
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
|
||||||
import boardsReducer from 'features/gallery/store/boardSlice';
|
|
||||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||||
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
|
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
|
||||||
|
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
|
||||||
import loraReducer from 'features/lora/store/loraSlice';
|
import loraReducer from 'features/lora/store/loraSlice';
|
||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
@ -43,9 +43,9 @@ const allReducers = {
|
|||||||
ui: uiReducer,
|
ui: uiReducer,
|
||||||
hotkeys: hotkeysReducer,
|
hotkeys: hotkeysReducer,
|
||||||
controlNet: controlNetReducer,
|
controlNet: controlNetReducer,
|
||||||
boards: boardsReducer,
|
|
||||||
dynamicPrompts: dynamicPromptsReducer,
|
dynamicPrompts: dynamicPromptsReducer,
|
||||||
imageDeletion: imageDeletionReducer,
|
deleteImageModal: deleteImageModalReducer,
|
||||||
|
changeBoardModal: changeBoardModalReducer,
|
||||||
lora: loraReducer,
|
lora: loraReducer,
|
||||||
modelmanager: modelmanagerReducer,
|
modelmanager: modelmanagerReducer,
|
||||||
sdxl: sdxlReducer,
|
sdxl: sdxlReducer,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import { Flex, Text, useColorMode } from '@chakra-ui/react';
|
import { Box, Flex, useColorMode } from '@chakra-ui/react';
|
||||||
import { motion } from 'framer-motion';
|
import { motion } from 'framer-motion';
|
||||||
import { ReactNode, memo, useRef } from 'react';
|
import { ReactNode, memo, useRef } from 'react';
|
||||||
import { mode } from 'theme/util/mode';
|
import { mode } from 'theme/util/mode';
|
||||||
@ -74,7 +74,7 @@ export const IAIDropOverlay = (props: Props) => {
|
|||||||
justifyContent: 'center',
|
justifyContent: 'center',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<Text
|
<Box
|
||||||
sx={{
|
sx={{
|
||||||
fontSize: '2xl',
|
fontSize: '2xl',
|
||||||
fontWeight: 600,
|
fontWeight: 600,
|
||||||
@ -87,7 +87,7 @@ export const IAIDropOverlay = (props: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{label}
|
{label}
|
||||||
</Text>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
</Flex>
|
</Flex>
|
||||||
</motion.div>
|
</motion.div>
|
||||||
|
@ -53,7 +53,9 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
|
|||||||
// wrap onChange to clear search value on select
|
// wrap onChange to clear search value on select
|
||||||
const handleChange = useCallback(
|
const handleChange = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
setSearchValue('');
|
// cannot figure out why we were doing this, but it was causing an issue where if you
|
||||||
|
// select the currently-selected item, it reset the search value to empty
|
||||||
|
// setSearchValue('');
|
||||||
|
|
||||||
if (!onChange) {
|
if (!onChange) {
|
||||||
return;
|
return;
|
||||||
|
@ -78,7 +78,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
|||||||
image_category: 'user',
|
image_category: 'user',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
postUploadAction,
|
postUploadAction,
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[autoAddBoardId, postUploadAction, uploadImage]
|
[autoAddBoardId, postUploadAction, uploadImage]
|
||||||
|
@ -49,7 +49,7 @@ export const useImageUploadButton = ({
|
|||||||
image_category: 'user',
|
image_category: 'user',
|
||||||
is_intermediate: false,
|
is_intermediate: false,
|
||||||
postUploadAction: postUploadAction ?? { type: 'TOAST' },
|
postUploadAction: postUploadAction ?? { type: 'TOAST' },
|
||||||
board_id: autoAddBoardId,
|
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[autoAddBoardId, postUploadAction, uploadImage]
|
[autoAddBoardId, postUploadAction, uploadImage]
|
||||||
|
@ -33,6 +33,10 @@ const useColorPicker = () => {
|
|||||||
1
|
1
|
||||||
).data;
|
).data;
|
||||||
|
|
||||||
|
if (!(a && r && g && b)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(setColorPickerColor({ r, g, b, a }));
|
dispatch(setColorPickerColor({ r, g, b, a }));
|
||||||
},
|
},
|
||||||
commitColorUnderCursor: () => {
|
commitColorUnderCursor: () => {
|
||||||
|
@ -727,10 +727,13 @@ export const canvasSlice = createSlice({
|
|||||||
state.pastLayerStates.shift();
|
state.pastLayerStates.shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
state.layerState.objects.push({
|
const imageToCommit = images[selectedImageIndex];
|
||||||
...images[selectedImageIndex],
|
|
||||||
});
|
|
||||||
|
|
||||||
|
if (imageToCommit) {
|
||||||
|
state.layerState.objects.push({
|
||||||
|
...imageToCommit,
|
||||||
|
});
|
||||||
|
}
|
||||||
state.layerState.stagingArea = {
|
state.layerState.stagingArea = {
|
||||||
...initialLayerState.stagingArea,
|
...initialLayerState.stagingArea,
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,132 @@
|
|||||||
|
import {
|
||||||
|
AlertDialog,
|
||||||
|
AlertDialogBody,
|
||||||
|
AlertDialogContent,
|
||||||
|
AlertDialogFooter,
|
||||||
|
AlertDialogHeader,
|
||||||
|
AlertDialogOverlay,
|
||||||
|
Flex,
|
||||||
|
Text,
|
||||||
|
} from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIButton from 'common/components/IAIButton';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { memo, useCallback, useMemo, useRef, useState } from 'react';
|
||||||
|
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
||||||
|
import {
|
||||||
|
useAddImagesToBoardMutation,
|
||||||
|
useRemoveImagesFromBoardMutation,
|
||||||
|
} from 'services/api/endpoints/images';
|
||||||
|
import { changeBoardReset, isModalOpenChanged } from '../store/slice';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ changeBoardModal }) => {
|
||||||
|
const { isModalOpen, imagesToChange } = changeBoardModal;
|
||||||
|
|
||||||
|
return {
|
||||||
|
isModalOpen,
|
||||||
|
imagesToChange,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ChangeBoardModal = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
||||||
|
const { data: boards, isFetching } = useListAllBoardsQuery();
|
||||||
|
const { imagesToChange, isModalOpen } = useAppSelector(selector);
|
||||||
|
const [addImagesToBoard] = useAddImagesToBoardMutation();
|
||||||
|
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
const data: { label: string; value: string }[] = [
|
||||||
|
{ label: 'Uncategorized', value: 'none' },
|
||||||
|
];
|
||||||
|
(boards ?? []).forEach((board) =>
|
||||||
|
data.push({
|
||||||
|
label: board.board_name,
|
||||||
|
value: board.board_id,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [boards]);
|
||||||
|
|
||||||
|
const handleClose = useCallback(() => {
|
||||||
|
dispatch(changeBoardReset());
|
||||||
|
dispatch(isModalOpenChanged(false));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
const handleChangeBoard = useCallback(() => {
|
||||||
|
if (!imagesToChange.length || !selectedBoard) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selectedBoard === 'none') {
|
||||||
|
removeImagesFromBoard({ imageDTOs: imagesToChange });
|
||||||
|
} else {
|
||||||
|
addImagesToBoard({
|
||||||
|
imageDTOs: imagesToChange,
|
||||||
|
board_id: selectedBoard,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
setSelectedBoard(null);
|
||||||
|
dispatch(changeBoardReset());
|
||||||
|
}, [
|
||||||
|
addImagesToBoard,
|
||||||
|
dispatch,
|
||||||
|
imagesToChange,
|
||||||
|
removeImagesFromBoard,
|
||||||
|
selectedBoard,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<AlertDialog
|
||||||
|
isOpen={isModalOpen}
|
||||||
|
onClose={handleClose}
|
||||||
|
leastDestructiveRef={cancelRef}
|
||||||
|
isCentered
|
||||||
|
>
|
||||||
|
<AlertDialogOverlay>
|
||||||
|
<AlertDialogContent>
|
||||||
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
|
Change Board
|
||||||
|
</AlertDialogHeader>
|
||||||
|
|
||||||
|
<AlertDialogBody>
|
||||||
|
<Flex sx={{ flexDir: 'column', gap: 4 }}>
|
||||||
|
<Text>
|
||||||
|
Moving {`${imagesToChange.length}`} image
|
||||||
|
{`${imagesToChange.length > 1 ? 's' : ''}`} to board:
|
||||||
|
</Text>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
placeholder={isFetching ? 'Loading...' : 'Select Board'}
|
||||||
|
disabled={isFetching}
|
||||||
|
onChange={(v) => setSelectedBoard(v)}
|
||||||
|
value={selectedBoard}
|
||||||
|
data={data}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
</AlertDialogBody>
|
||||||
|
<AlertDialogFooter>
|
||||||
|
<IAIButton ref={cancelRef} onClick={handleClose}>
|
||||||
|
Cancel
|
||||||
|
</IAIButton>
|
||||||
|
<IAIButton colorScheme="accent" onClick={handleChangeBoard} ml={3}>
|
||||||
|
Move
|
||||||
|
</IAIButton>
|
||||||
|
</AlertDialogFooter>
|
||||||
|
</AlertDialogContent>
|
||||||
|
</AlertDialogOverlay>
|
||||||
|
</AlertDialog>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ChangeBoardModal);
|
@ -0,0 +1,6 @@
|
|||||||
|
import { ChangeBoardModalState } from './types';
|
||||||
|
|
||||||
|
export const initialState: ChangeBoardModalState = {
|
||||||
|
isModalOpen: false,
|
||||||
|
imagesToChange: [],
|
||||||
|
};
|
@ -0,0 +1,25 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { initialState } from './initialState';
|
||||||
|
|
||||||
|
const changeBoardModal = createSlice({
|
||||||
|
name: 'changeBoardModal',
|
||||||
|
initialState,
|
||||||
|
reducers: {
|
||||||
|
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isModalOpen = action.payload;
|
||||||
|
},
|
||||||
|
imagesToChangeSelected: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||||
|
state.imagesToChange = action.payload;
|
||||||
|
},
|
||||||
|
changeBoardReset: (state) => {
|
||||||
|
state.imagesToChange = [];
|
||||||
|
state.isModalOpen = false;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } =
|
||||||
|
changeBoardModal.actions;
|
||||||
|
|
||||||
|
export default changeBoardModal.reducer;
|
@ -0,0 +1,6 @@
|
|||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
|
export type ChangeBoardModalState = {
|
||||||
|
isModalOpen: boolean;
|
||||||
|
imagesToChange: ImageDTO[];
|
||||||
|
};
|
@ -3,6 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { FaCopy, FaTrash } from 'react-icons/fa';
|
import { FaCopy, FaTrash } from 'react-icons/fa';
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
controlNetDuplicated,
|
controlNetDuplicated,
|
||||||
controlNetRemoved,
|
controlNetRemoved,
|
||||||
controlNetToggled,
|
controlNetToggled,
|
||||||
@ -27,18 +28,27 @@ import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcesso
|
|||||||
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
|
||||||
|
|
||||||
type ControlNetProps = {
|
type ControlNetProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNet = (props: ControlNetProps) => {
|
const ControlNet = (props: ControlNetProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNet } = props;
|
||||||
|
const { controlNetId } = controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlNet }) => {
|
({ controlNet }) => {
|
||||||
const { isEnabled, shouldAutoConfig } =
|
const cn = controlNet.controlNets[controlNetId];
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
|
if (!cn) {
|
||||||
|
return {
|
||||||
|
isEnabled: false,
|
||||||
|
shouldAutoConfig: false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const { isEnabled, shouldAutoConfig } = cn;
|
||||||
|
|
||||||
return { isEnabled, shouldAutoConfig };
|
return { isEnabled, shouldAutoConfig };
|
||||||
},
|
},
|
||||||
@ -96,7 +106,7 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
transitionDuration: '0.1s',
|
transitionDuration: '0.1s',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetModel controlNetId={controlNetId} />
|
<ParamControlNetModel controlNet={controlNet} />
|
||||||
</Box>
|
</Box>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
size="sm"
|
size="sm"
|
||||||
@ -171,8 +181,8 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
justifyContent: 'space-between',
|
justifyContent: 'space-between',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ParamControlNetWeight controlNetId={controlNetId} />
|
<ParamControlNetWeight controlNet={controlNet} />
|
||||||
<ParamControlNetBeginEnd controlNetId={controlNetId} />
|
<ParamControlNetBeginEnd controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
{!isExpanded && (
|
{!isExpanded && (
|
||||||
<Flex
|
<Flex
|
||||||
@ -184,22 +194,22 @@ const ControlNet = (props: ControlNetProps) => {
|
|||||||
aspectRatio: '1/1',
|
aspectRatio: '1/1',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ControlNetImagePreview controlNetId={controlNetId} height={28} />
|
<ControlNetImagePreview controlNet={controlNet} height={28} />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex sx={{ gap: 2 }}>
|
<Flex sx={{ gap: 2 }}>
|
||||||
<ParamControlNetControlMode controlNetId={controlNetId} />
|
<ParamControlNetControlMode controlNet={controlNet} />
|
||||||
<ParamControlNetResizeMode controlNetId={controlNetId} />
|
<ParamControlNetResizeMode controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
<ParamControlNetProcessorSelect controlNetId={controlNetId} />
|
<ParamControlNetProcessorSelect controlNet={controlNet} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{isExpanded && (
|
{isExpanded && (
|
||||||
<>
|
<>
|
||||||
<ControlNetImagePreview controlNetId={controlNetId} height="392px" />
|
<ControlNetImagePreview controlNet={controlNet} height="392px" />
|
||||||
<ParamControlNetShouldAutoConfig controlNetId={controlNetId} />
|
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
|
||||||
<ControlNetProcessorComponent controlNetId={controlNetId} />
|
<ControlNetProcessorComponent controlNet={controlNet} />
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -12,50 +12,41 @@ import IAIDndImage from 'common/components/IAIDndImage';
|
|||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { PostUploadAction } from 'services/api/types';
|
import { PostUploadAction } from 'services/api/types';
|
||||||
import { controlNetImageChanged } from '../store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetImageChanged,
|
||||||
|
} from '../store/controlNetSlice';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
height: SystemStyleObject['h'];
|
height: SystemStyleObject['h'];
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNetImagePreview = (props: Props) => {
|
const selector = createSelector(
|
||||||
const { height, controlNetId } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ controlNet }) => {
|
({ controlNet }) => {
|
||||||
const { pendingControlImages } = controlNet;
|
const { pendingControlImages } = controlNet;
|
||||||
const {
|
|
||||||
controlImage,
|
|
||||||
processedControlImage,
|
|
||||||
processorType,
|
|
||||||
isEnabled,
|
|
||||||
} = controlNet.controlNets[controlNetId];
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
controlImageName: controlImage,
|
|
||||||
processedControlImageName: processedControlImage,
|
|
||||||
processorType,
|
|
||||||
isEnabled,
|
|
||||||
pendingControlImages,
|
pendingControlImages,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const ControlNetImagePreview = (props: Props) => {
|
||||||
|
const { height } = props;
|
||||||
const {
|
const {
|
||||||
controlImageName,
|
controlImage: controlImageName,
|
||||||
processedControlImageName,
|
processedControlImage: processedControlImageName,
|
||||||
processorType,
|
processorType,
|
||||||
pendingControlImages,
|
|
||||||
isEnabled,
|
isEnabled,
|
||||||
} = useAppSelector(selector);
|
controlNetId,
|
||||||
|
} = props.controlNet;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { pendingControlImages } = useAppSelector(selector);
|
||||||
|
|
||||||
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
const [isMouseOverImage, setIsMouseOverImage] = useState(false);
|
||||||
|
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { memo } from 'react';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { ControlNetConfig } from '../store/controlNetSlice';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import { memo, useMemo } from 'react';
|
|
||||||
import CannyProcessor from './processors/CannyProcessor';
|
import CannyProcessor from './processors/CannyProcessor';
|
||||||
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
|
||||||
import HedProcessor from './processors/HedProcessor';
|
import HedProcessor from './processors/HedProcessor';
|
||||||
@ -17,28 +14,11 @@ import PidiProcessor from './processors/PidiProcessor';
|
|||||||
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
|
||||||
|
|
||||||
export type ControlNetProcessorProps = {
|
export type ControlNetProcessorProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
||||||
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, processorNode } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
|
|
||||||
return { isEnabled, processorNode };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { isEnabled, processorNode } = useAppSelector(selector);
|
|
||||||
|
|
||||||
if (processorNode.type === 'canny_image_processor') {
|
if (processorNode.type === 'canny_image_processor') {
|
||||||
return (
|
return (
|
||||||
|
@ -1,34 +1,19 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { controlNetAutoConfigToggled } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetAutoConfigToggled,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
const ParamControlNetShouldAutoConfig = (props: Props) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, shouldAutoConfig } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, shouldAutoConfig } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { isEnabled, shouldAutoConfig };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
|
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const handleShouldAutoConfigChanged = useCallback(() => {
|
const handleShouldAutoConfigChanged = useCallback(() => {
|
||||||
|
@ -9,48 +9,39 @@ import {
|
|||||||
RangeSliderTrack,
|
RangeSliderTrack,
|
||||||
Tooltip,
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
controlNetBeginStepPctChanged,
|
controlNetBeginStepPctChanged,
|
||||||
controlNetEndStepPctChanged,
|
controlNetEndStepPctChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
|
||||||
|
|
||||||
const ParamControlNetBeginEnd = (props: Props) => {
|
const ParamControlNetBeginEnd = (props: Props) => {
|
||||||
const { controlNetId } = props;
|
const { beginStepPct, endStepPct, isEnabled, controlNetId } =
|
||||||
|
props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { beginStepPct, endStepPct, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { beginStepPct, endStepPct, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { beginStepPct, endStepPct, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleStepPctChanged = useCallback(
|
const handleStepPctChanged = useCallback(
|
||||||
(v: number[]) => {
|
(v: number[]) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
controlNetBeginStepPctChanged({ controlNetId, beginStepPct: v[0] })
|
controlNetBeginStepPctChanged({
|
||||||
|
controlNetId,
|
||||||
|
beginStepPct: v[0] as number,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
dispatch(
|
||||||
|
controlNetEndStepPctChanged({
|
||||||
|
controlNetId,
|
||||||
|
endStepPct: v[1] as number,
|
||||||
|
})
|
||||||
);
|
);
|
||||||
dispatch(controlNetEndStepPctChanged({ controlNetId, endStepPct: v[1] }));
|
|
||||||
},
|
},
|
||||||
[controlNetId, dispatch]
|
[controlNetId, dispatch]
|
||||||
);
|
);
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import {
|
||||||
ControlModes,
|
ControlModes,
|
||||||
|
ControlNetConfig,
|
||||||
controlNetControlModeChanged,
|
controlNetControlModeChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetControlModeProps = {
|
type ParamControlNetControlModeProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const CONTROL_MODE_DATA = [
|
const CONTROL_MODE_DATA = [
|
||||||
@ -23,23 +21,8 @@ const CONTROL_MODE_DATA = [
|
|||||||
export default function ParamControlNetControlMode(
|
export default function ParamControlNetControlMode(
|
||||||
props: ParamControlNetControlModeProps
|
props: ParamControlNetControlModeProps
|
||||||
) {
|
) {
|
||||||
const { controlNetId } = props;
|
const { controlMode, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { controlMode, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { controlMode, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { controlMode, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleControlModeChange = useCallback(
|
const handleControlModeChange = useCallback(
|
||||||
(controlMode: ControlModes) => {
|
(controlMode: ControlModes) => {
|
||||||
|
@ -5,7 +5,10 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
|
||||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetModelChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
@ -14,30 +17,24 @@ import { memo, useCallback, useMemo } from 'react';
|
|||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type ParamControlNetModelProps = {
|
type ParamControlNetModelProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ generation }) => {
|
||||||
|
const { model } = generation;
|
||||||
|
return { mainModel: model };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||||
const { controlNetId } = props;
|
const { controlNetId, model: controlNetModel, isEnabled } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
|
|
||||||
const selector = useMemo(
|
const { mainModel } = useAppSelector(selector);
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ generation, controlNet }) => {
|
|
||||||
const { model } = generation;
|
|
||||||
const controlNetModel = controlNet.controlNets[controlNetId]?.model;
|
|
||||||
const isEnabled = controlNet.controlNets[controlNetId]?.isEnabled;
|
|
||||||
return { mainModel: model, controlNetModel, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { mainModel, controlNetModel, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIMantineSearchableSelect, {
|
import IAIMantineSearchableSelect, {
|
||||||
IAISelectDataType,
|
IAISelectDataType,
|
||||||
@ -9,13 +8,16 @@ import IAIMantineSearchableSelect, {
|
|||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||||
import { map } from 'lodash-es';
|
import { map } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
|
controlNetProcessorTypeChanged,
|
||||||
|
} from '../../store/controlNetSlice';
|
||||||
import { ControlNetProcessorType } from '../../store/types';
|
import { ControlNetProcessorType } from '../../store/types';
|
||||||
|
|
||||||
type ParamControlNetProcessorSelectProps = {
|
type ParamControlNetProcessorSelectProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
@ -52,23 +54,9 @@ const ParamControlNetProcessorSelect = (
|
|||||||
props: ParamControlNetProcessorSelectProps
|
props: ParamControlNetProcessorSelectProps
|
||||||
) => {
|
) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { controlNetId } = props;
|
const { controlNetId, isEnabled, processorNode } = props.controlNet;
|
||||||
const processorNodeSelector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { isEnabled, processorNode } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { isEnabled, processorNode };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
const isBusy = useAppSelector(selectIsBusy);
|
const isBusy = useAppSelector(selectIsBusy);
|
||||||
const controlNetProcessors = useAppSelector(selector);
|
const controlNetProcessors = useAppSelector(selector);
|
||||||
const { isEnabled, processorNode } = useAppSelector(processorNodeSelector);
|
|
||||||
|
|
||||||
const handleProcessorTypeChanged = useCallback(
|
const handleProcessorTypeChanged = useCallback(
|
||||||
(v: string | null) => {
|
(v: string | null) => {
|
||||||
|
@ -1,16 +1,14 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import {
|
import {
|
||||||
|
ControlNetConfig,
|
||||||
ResizeModes,
|
ResizeModes,
|
||||||
controlNetResizeModeChanged,
|
controlNetResizeModeChanged,
|
||||||
} from 'features/controlNet/store/controlNetSlice';
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetResizeModeProps = {
|
type ParamControlNetResizeModeProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const RESIZE_MODE_DATA = [
|
const RESIZE_MODE_DATA = [
|
||||||
@ -22,23 +20,8 @@ const RESIZE_MODE_DATA = [
|
|||||||
export default function ParamControlNetResizeMode(
|
export default function ParamControlNetResizeMode(
|
||||||
props: ParamControlNetResizeModeProps
|
props: ParamControlNetResizeModeProps
|
||||||
) {
|
) {
|
||||||
const { controlNetId } = props;
|
const { resizeMode, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { resizeMode, isEnabled } =
|
|
||||||
controlNet.controlNets[controlNetId];
|
|
||||||
return { resizeMode, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { resizeMode, isEnabled } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleResizeModeChange = useCallback(
|
const handleResizeModeChange = useCallback(
|
||||||
(resizeMode: ResizeModes) => {
|
(resizeMode: ResizeModes) => {
|
||||||
|
@ -1,32 +1,18 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { controlNetWeightChanged } from 'features/controlNet/store/controlNetSlice';
|
import {
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
ControlNetConfig,
|
||||||
|
controlNetWeightChanged,
|
||||||
|
} from 'features/controlNet/store/controlNetSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type ParamControlNetWeightProps = {
|
type ParamControlNetWeightProps = {
|
||||||
controlNetId: string;
|
controlNet: ControlNetConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
const ParamControlNetWeight = (props: ParamControlNetWeightProps) => {
|
||||||
const { controlNetId } = props;
|
const { weight, isEnabled, controlNetId } = props.controlNet;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selector = useMemo(
|
|
||||||
() =>
|
|
||||||
createSelector(
|
|
||||||
stateSelector,
|
|
||||||
({ controlNet }) => {
|
|
||||||
const { weight, isEnabled } = controlNet.controlNets[controlNetId];
|
|
||||||
return { weight, isEnabled };
|
|
||||||
},
|
|
||||||
defaultSelectorOptions
|
|
||||||
),
|
|
||||||
[controlNetId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const { weight, isEnabled } = useAppSelector(selector);
|
|
||||||
const handleWeightChanged = useCallback(
|
const handleWeightChanged = useCallback(
|
||||||
(weight: number) => {
|
(weight: number) => {
|
||||||
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
dispatch(controlNetWeightChanged({ controlNetId, weight }));
|
||||||
|
@ -4,7 +4,7 @@ import {
|
|||||||
} from './types';
|
} from './types';
|
||||||
|
|
||||||
type ControlNetProcessorsDict = Record<
|
type ControlNetProcessorsDict = Record<
|
||||||
string,
|
ControlNetProcessorType,
|
||||||
{
|
{
|
||||||
type: ControlNetProcessorType | 'none';
|
type: ControlNetProcessorType | 'none';
|
||||||
label: string;
|
label: string;
|
||||||
|
@ -96,8 +96,11 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { sourceControlNetId, newControlNetId } = action.payload;
|
const { sourceControlNetId, newControlNetId } = action.payload;
|
||||||
|
const oldControlNet = state.controlNets[sourceControlNetId];
|
||||||
const newControlnet = cloneDeep(state.controlNets[sourceControlNetId]);
|
if (!oldControlNet) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const newControlnet = cloneDeep(oldControlNet);
|
||||||
newControlnet.controlNetId = newControlNetId;
|
newControlnet.controlNetId = newControlNetId;
|
||||||
state.controlNets[newControlNetId] = newControlnet;
|
state.controlNets[newControlNetId] = newControlnet;
|
||||||
},
|
},
|
||||||
@ -124,8 +127,11 @@ export const controlNetSlice = createSlice({
|
|||||||
action: PayloadAction<{ controlNetId: string }>
|
action: PayloadAction<{ controlNetId: string }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId } = action.payload;
|
const { controlNetId } = action.payload;
|
||||||
state.controlNets[controlNetId].isEnabled =
|
const cn = state.controlNets[controlNetId];
|
||||||
!state.controlNets[controlNetId].isEnabled;
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
cn.isEnabled = !cn.isEnabled;
|
||||||
},
|
},
|
||||||
controlNetImageChanged: (
|
controlNetImageChanged: (
|
||||||
state,
|
state,
|
||||||
@ -135,12 +141,14 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, controlImage } = action.payload;
|
const { controlNetId, controlImage } = action.payload;
|
||||||
state.controlNets[controlNetId].controlImage = controlImage;
|
const cn = state.controlNets[controlNetId];
|
||||||
state.controlNets[controlNetId].processedControlImage = null;
|
if (!cn) {
|
||||||
if (
|
return;
|
||||||
controlImage !== null &&
|
}
|
||||||
state.controlNets[controlNetId].processorType !== 'none'
|
|
||||||
) {
|
cn.controlImage = controlImage;
|
||||||
|
cn.processedControlImage = null;
|
||||||
|
if (controlImage !== null && cn.processorType !== 'none') {
|
||||||
state.pendingControlImages.push(controlNetId);
|
state.pendingControlImages.push(controlNetId);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -152,8 +160,12 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, processedControlImage } = action.payload;
|
const { controlNetId, processedControlImage } = action.payload;
|
||||||
state.controlNets[controlNetId].processedControlImage =
|
const cn = state.controlNets[controlNetId];
|
||||||
processedControlImage;
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.processedControlImage = processedControlImage;
|
||||||
state.pendingControlImages = state.pendingControlImages.filter(
|
state.pendingControlImages = state.pendingControlImages.filter(
|
||||||
(id) => id !== controlNetId
|
(id) => id !== controlNetId
|
||||||
);
|
);
|
||||||
@ -166,10 +178,15 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, model } = action.payload;
|
const { controlNetId, model } = action.payload;
|
||||||
state.controlNets[controlNetId].model = model;
|
const cn = state.controlNets[controlNetId];
|
||||||
state.controlNets[controlNetId].processedControlImage = null;
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
cn.model = model;
|
||||||
|
cn.processedControlImage = null;
|
||||||
|
|
||||||
|
if (cn.shouldAutoConfig) {
|
||||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
@ -180,14 +197,13 @@ export const controlNetSlice = createSlice({
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
cn.processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||||
processorType
|
.default as RequiredControlNetProcessorNode;
|
||||||
].default as RequiredControlNetProcessorNode;
|
|
||||||
} else {
|
} else {
|
||||||
state.controlNets[controlNetId].processorType = 'none';
|
cn.processorType = 'none';
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
|
cn.processorNode = CONTROLNET_PROCESSORS.none
|
||||||
.none.default as RequiredControlNetProcessorNode;
|
.default as RequiredControlNetProcessorNode;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -196,28 +212,48 @@ export const controlNetSlice = createSlice({
|
|||||||
action: PayloadAction<{ controlNetId: string; weight: number }>
|
action: PayloadAction<{ controlNetId: string; weight: number }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, weight } = action.payload;
|
const { controlNetId, weight } = action.payload;
|
||||||
state.controlNets[controlNetId].weight = weight;
|
const cn = state.controlNets[controlNetId];
|
||||||
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.weight = weight;
|
||||||
},
|
},
|
||||||
controlNetBeginStepPctChanged: (
|
controlNetBeginStepPctChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
|
action: PayloadAction<{ controlNetId: string; beginStepPct: number }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, beginStepPct } = action.payload;
|
const { controlNetId, beginStepPct } = action.payload;
|
||||||
state.controlNets[controlNetId].beginStepPct = beginStepPct;
|
const cn = state.controlNets[controlNetId];
|
||||||
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.beginStepPct = beginStepPct;
|
||||||
},
|
},
|
||||||
controlNetEndStepPctChanged: (
|
controlNetEndStepPctChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
|
action: PayloadAction<{ controlNetId: string; endStepPct: number }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, endStepPct } = action.payload;
|
const { controlNetId, endStepPct } = action.payload;
|
||||||
state.controlNets[controlNetId].endStepPct = endStepPct;
|
const cn = state.controlNets[controlNetId];
|
||||||
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.endStepPct = endStepPct;
|
||||||
},
|
},
|
||||||
controlNetControlModeChanged: (
|
controlNetControlModeChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
|
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, controlMode } = action.payload;
|
const { controlNetId, controlMode } = action.payload;
|
||||||
state.controlNets[controlNetId].controlMode = controlMode;
|
const cn = state.controlNets[controlNetId];
|
||||||
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.controlMode = controlMode;
|
||||||
},
|
},
|
||||||
controlNetResizeModeChanged: (
|
controlNetResizeModeChanged: (
|
||||||
state,
|
state,
|
||||||
@ -227,7 +263,12 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, resizeMode } = action.payload;
|
const { controlNetId, resizeMode } = action.payload;
|
||||||
state.controlNets[controlNetId].resizeMode = resizeMode;
|
const cn = state.controlNets[controlNetId];
|
||||||
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cn.resizeMode = resizeMode;
|
||||||
},
|
},
|
||||||
controlNetProcessorParamsChanged: (
|
controlNetProcessorParamsChanged: (
|
||||||
state,
|
state,
|
||||||
@ -240,12 +281,17 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, changes } = action.payload;
|
const { controlNetId, changes } = action.payload;
|
||||||
const processorNode = state.controlNets[controlNetId].processorNode;
|
const cn = state.controlNets[controlNetId];
|
||||||
state.controlNets[controlNetId].processorNode = {
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const processorNode = cn.processorNode;
|
||||||
|
cn.processorNode = {
|
||||||
...processorNode,
|
...processorNode,
|
||||||
...changes,
|
...changes,
|
||||||
};
|
};
|
||||||
state.controlNets[controlNetId].shouldAutoConfig = false;
|
cn.shouldAutoConfig = false;
|
||||||
},
|
},
|
||||||
controlNetProcessorTypeChanged: (
|
controlNetProcessorTypeChanged: (
|
||||||
state,
|
state,
|
||||||
@ -255,12 +301,16 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId, processorType } = action.payload;
|
const { controlNetId, processorType } = action.payload;
|
||||||
state.controlNets[controlNetId].processedControlImage = null;
|
const cn = state.controlNets[controlNetId];
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
if (!cn) {
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
return;
|
||||||
processorType
|
}
|
||||||
].default as RequiredControlNetProcessorNode;
|
|
||||||
state.controlNets[controlNetId].shouldAutoConfig = false;
|
cn.processedControlImage = null;
|
||||||
|
cn.processorType = processorType;
|
||||||
|
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||||
|
.default as RequiredControlNetProcessorNode;
|
||||||
|
cn.shouldAutoConfig = false;
|
||||||
},
|
},
|
||||||
controlNetAutoConfigToggled: (
|
controlNetAutoConfigToggled: (
|
||||||
state,
|
state,
|
||||||
@ -269,37 +319,36 @@ export const controlNetSlice = createSlice({
|
|||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { controlNetId } = action.payload;
|
const { controlNetId } = action.payload;
|
||||||
const newShouldAutoConfig =
|
const cn = state.controlNets[controlNetId];
|
||||||
!state.controlNets[controlNetId].shouldAutoConfig;
|
if (!cn) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newShouldAutoConfig = !cn.shouldAutoConfig;
|
||||||
|
|
||||||
if (newShouldAutoConfig) {
|
if (newShouldAutoConfig) {
|
||||||
// manage the processor for the user
|
// manage the processor for the user
|
||||||
let processorType: ControlNetProcessorType | undefined = undefined;
|
let processorType: ControlNetProcessorType | undefined = undefined;
|
||||||
|
|
||||||
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) {
|
||||||
if (
|
if (cn.model?.model_name.includes(modelSubstring)) {
|
||||||
state.controlNets[controlNetId].model?.model_name.includes(
|
|
||||||
modelSubstring
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
processorType = CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (processorType) {
|
if (processorType) {
|
||||||
state.controlNets[controlNetId].processorType = processorType;
|
cn.processorType = processorType;
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
cn.processorNode = CONTROLNET_PROCESSORS[processorType]
|
||||||
processorType
|
.default as RequiredControlNetProcessorNode;
|
||||||
].default as RequiredControlNetProcessorNode;
|
|
||||||
} else {
|
} else {
|
||||||
state.controlNets[controlNetId].processorType = 'none';
|
cn.processorType = 'none';
|
||||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS
|
cn.processorNode = CONTROLNET_PROCESSORS.none
|
||||||
.none.default as RequiredControlNetProcessorNode;
|
.default as RequiredControlNetProcessorNode;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state.controlNets[controlNetId].shouldAutoConfig = newShouldAutoConfig;
|
cn.shouldAutoConfig = newShouldAutoConfig;
|
||||||
},
|
},
|
||||||
controlNetReset: () => {
|
controlNetReset: () => {
|
||||||
return { ...initialControlNetState };
|
return { ...initialControlNetState };
|
||||||
@ -307,9 +356,11 @@ export const controlNetSlice = createSlice({
|
|||||||
},
|
},
|
||||||
extraReducers: (builder) => {
|
extraReducers: (builder) => {
|
||||||
builder.addCase(controlNetImageProcessed, (state, action) => {
|
builder.addCase(controlNetImageProcessed, (state, action) => {
|
||||||
if (
|
const cn = state.controlNets[action.payload.controlNetId];
|
||||||
state.controlNets[action.payload.controlNetId].controlImage !== null
|
if (!cn) {
|
||||||
) {
|
return;
|
||||||
|
}
|
||||||
|
if (cn.controlImage !== null) {
|
||||||
state.pendingControlImages.push(action.payload.controlNetId);
|
state.pendingControlImages.push(action.payload.controlNetId);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -15,30 +15,42 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAISwitch from 'common/components/IAISwitch';
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
|
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
|
||||||
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { some } from 'lodash-es';
|
||||||
import { ChangeEvent, memo, useCallback, useRef } from 'react';
|
import { ChangeEvent, memo, useCallback, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { imageDeletionConfirmed } from '../store/actions';
|
import { imageDeletionConfirmed } from '../store/actions';
|
||||||
import { selectImageUsage } from '../store/imageDeletionSelectors';
|
import { getImageUsage, selectImageUsage } from '../store/selectors';
|
||||||
import {
|
import { imageDeletionCanceled, isModalOpenChanged } from '../store/slice';
|
||||||
imageToDeleteCleared,
|
|
||||||
isModalOpenChanged,
|
|
||||||
} from '../store/imageDeletionSlice';
|
|
||||||
import ImageUsageMessage from './ImageUsageMessage';
|
import ImageUsageMessage from './ImageUsageMessage';
|
||||||
|
import { ImageUsage } from '../store/types';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector, selectImageUsage],
|
[stateSelector, selectImageUsage],
|
||||||
({ system, config, imageDeletion }, imageUsage) => {
|
(state, imagesUsage) => {
|
||||||
|
const { system, config, deleteImageModal } = state;
|
||||||
const { shouldConfirmOnDelete } = system;
|
const { shouldConfirmOnDelete } = system;
|
||||||
const { canRestoreDeletedImagesFromBin } = config;
|
const { canRestoreDeletedImagesFromBin } = config;
|
||||||
const { imageToDelete, isModalOpen } = imageDeletion;
|
const { imagesToDelete, isModalOpen } = deleteImageModal;
|
||||||
|
|
||||||
|
const allImageUsage = (imagesToDelete ?? []).map(({ image_name }) =>
|
||||||
|
getImageUsage(state, image_name)
|
||||||
|
);
|
||||||
|
|
||||||
|
const imageUsageSummary: ImageUsage = {
|
||||||
|
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
|
||||||
|
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
|
||||||
|
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
|
||||||
|
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
|
||||||
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
shouldConfirmOnDelete,
|
shouldConfirmOnDelete,
|
||||||
canRestoreDeletedImagesFromBin,
|
canRestoreDeletedImagesFromBin,
|
||||||
imageToDelete,
|
imagesToDelete,
|
||||||
imageUsage,
|
imagesUsage,
|
||||||
isModalOpen,
|
isModalOpen,
|
||||||
|
imageUsageSummary,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -51,9 +63,10 @@ const DeleteImageModal = () => {
|
|||||||
const {
|
const {
|
||||||
shouldConfirmOnDelete,
|
shouldConfirmOnDelete,
|
||||||
canRestoreDeletedImagesFromBin,
|
canRestoreDeletedImagesFromBin,
|
||||||
imageToDelete,
|
imagesToDelete,
|
||||||
imageUsage,
|
imagesUsage,
|
||||||
isModalOpen,
|
isModalOpen,
|
||||||
|
imageUsageSummary,
|
||||||
} = useAppSelector(selector);
|
} = useAppSelector(selector);
|
||||||
|
|
||||||
const handleChangeShouldConfirmOnDelete = useCallback(
|
const handleChangeShouldConfirmOnDelete = useCallback(
|
||||||
@ -63,17 +76,19 @@ const DeleteImageModal = () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleClose = useCallback(() => {
|
const handleClose = useCallback(() => {
|
||||||
dispatch(imageToDeleteCleared());
|
dispatch(imageDeletionCanceled());
|
||||||
dispatch(isModalOpenChanged(false));
|
dispatch(isModalOpenChanged(false));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleDelete = useCallback(() => {
|
const handleDelete = useCallback(() => {
|
||||||
if (!imageToDelete || !imageUsage) {
|
if (!imagesToDelete.length || !imagesUsage.length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(imageToDeleteCleared());
|
dispatch(imageDeletionCanceled());
|
||||||
dispatch(imageDeletionConfirmed({ imageDTO: imageToDelete, imageUsage }));
|
dispatch(
|
||||||
}, [dispatch, imageToDelete, imageUsage]);
|
imageDeletionConfirmed({ imageDTOs: imagesToDelete, imagesUsage })
|
||||||
|
);
|
||||||
|
}, [dispatch, imagesToDelete, imagesUsage]);
|
||||||
|
|
||||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
const cancelRef = useRef<HTMLButtonElement>(null);
|
||||||
|
|
||||||
@ -92,7 +107,7 @@ const DeleteImageModal = () => {
|
|||||||
|
|
||||||
<AlertDialogBody>
|
<AlertDialogBody>
|
||||||
<Flex direction="column" gap={3}>
|
<Flex direction="column" gap={3}>
|
||||||
<ImageUsageMessage imageUsage={imageUsage} />
|
<ImageUsageMessage imageUsage={imageUsageSummary} />
|
||||||
<Divider />
|
<Divider />
|
||||||
<Text>
|
<Text>
|
||||||
{canRestoreDeletedImagesFromBin
|
{canRestoreDeletedImagesFromBin
|
@ -3,6 +3,6 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
import { ImageUsage } from './types';
|
import { ImageUsage } from './types';
|
||||||
|
|
||||||
export const imageDeletionConfirmed = createAction<{
|
export const imageDeletionConfirmed = createAction<{
|
||||||
imageDTO: ImageDTO;
|
imageDTOs: ImageDTO[];
|
||||||
imageUsage: ImageUsage;
|
imagesUsage: ImageUsage[];
|
||||||
}>('imageDeletion/imageDeletionConfirmed');
|
}>('deleteImageModal/imageDeletionConfirmed');
|
@ -0,0 +1,6 @@
|
|||||||
|
import { DeleteImageState } from './types';
|
||||||
|
|
||||||
|
export const initialDeleteImageState: DeleteImageState = {
|
||||||
|
imagesToDelete: [],
|
||||||
|
isModalOpen: false,
|
||||||
|
};
|
@ -39,17 +39,17 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
|||||||
export const selectImageUsage = createSelector(
|
export const selectImageUsage = createSelector(
|
||||||
[(state: RootState) => state],
|
[(state: RootState) => state],
|
||||||
(state) => {
|
(state) => {
|
||||||
const { imageToDelete } = state.imageDeletion;
|
const { imagesToDelete } = state.deleteImageModal;
|
||||||
|
|
||||||
if (!imageToDelete) {
|
if (!imagesToDelete.length) {
|
||||||
return;
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
const { image_name } = imageToDelete;
|
const imagesUsage = imagesToDelete.map((i) =>
|
||||||
|
getImageUsage(state, i.image_name)
|
||||||
|
);
|
||||||
|
|
||||||
const imageUsage = getImageUsage(state, image_name);
|
return imagesUsage;
|
||||||
|
|
||||||
return imageUsage;
|
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
@ -0,0 +1,28 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
import { initialDeleteImageState } from './initialState';
|
||||||
|
|
||||||
|
const deleteImageModal = createSlice({
|
||||||
|
name: 'deleteImageModal',
|
||||||
|
initialState: initialDeleteImageState,
|
||||||
|
reducers: {
|
||||||
|
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.isModalOpen = action.payload;
|
||||||
|
},
|
||||||
|
imagesToDeleteSelected: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||||
|
state.imagesToDelete = action.payload;
|
||||||
|
},
|
||||||
|
imageDeletionCanceled: (state) => {
|
||||||
|
state.imagesToDelete = [];
|
||||||
|
state.isModalOpen = false;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
isModalOpenChanged,
|
||||||
|
imagesToDeleteSelected,
|
||||||
|
imageDeletionCanceled,
|
||||||
|
} = deleteImageModal.actions;
|
||||||
|
|
||||||
|
export default deleteImageModal.reducer;
|
@ -0,0 +1,13 @@
|
|||||||
|
import { ImageDTO } from 'services/api/types';
|
||||||
|
|
||||||
|
export type DeleteImageState = {
|
||||||
|
imagesToDelete: ImageDTO[];
|
||||||
|
isModalOpen: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ImageUsage = {
|
||||||
|
isInitialImage: boolean;
|
||||||
|
isCanvasImage: boolean;
|
||||||
|
isNodesImage: boolean;
|
||||||
|
isControlNetImage: boolean;
|
||||||
|
};
|
@ -11,11 +11,14 @@ import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
|||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ gallery }) => {
|
({ gallery, system }) => {
|
||||||
const { autoAddBoardId } = gallery;
|
const { autoAddBoardId, autoAssignBoardOnClick } = gallery;
|
||||||
|
const { isProcessing } = system;
|
||||||
|
|
||||||
return {
|
return {
|
||||||
autoAddBoardId,
|
autoAddBoardId,
|
||||||
|
autoAssignBoardOnClick,
|
||||||
|
isProcessing,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
@ -23,7 +26,8 @@ const selector = createSelector(
|
|||||||
|
|
||||||
const BoardAutoAddSelect = () => {
|
const BoardAutoAddSelect = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { autoAddBoardId } = useAppSelector(selector);
|
const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } =
|
||||||
|
useAppSelector(selector);
|
||||||
const inputRef = useRef<HTMLInputElement>(null);
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
const { boards, hasBoards } = useListAllBoardsQuery(undefined, {
|
const { boards, hasBoards } = useListAllBoardsQuery(undefined, {
|
||||||
selectFromResult: ({ data }) => {
|
selectFromResult: ({ data }) => {
|
||||||
@ -52,7 +56,7 @@ const BoardAutoAddSelect = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(autoAddBoardIdChanged(v === 'none' ? undefined : v));
|
dispatch(autoAddBoardIdChanged(v));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -67,7 +71,7 @@ const BoardAutoAddSelect = () => {
|
|||||||
data={boards}
|
data={boards}
|
||||||
nothingFound="No matching Boards"
|
nothingFound="No matching Boards"
|
||||||
itemComponent={IAIMantineSelectItemWithTooltip}
|
itemComponent={IAIMantineSelectItemWithTooltip}
|
||||||
disabled={!hasBoards}
|
disabled={!hasBoards || autoAssignBoardOnClick || isProcessing}
|
||||||
filter={(value, item: SelectItem) =>
|
filter={(value, item: SelectItem) =>
|
||||||
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
item.label?.toLowerCase().includes(value.toLowerCase().trim()) ||
|
||||||
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
item.value.toLowerCase().includes(value.toLowerCase().trim())
|
||||||
|
@ -11,10 +11,11 @@ import { BoardDTO } from 'services/api/types';
|
|||||||
import { menuListMotionProps } from 'theme/components/menu';
|
import { menuListMotionProps } from 'theme/components/menu';
|
||||||
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
|
||||||
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
|
import NoBoardContextMenuItems from './NoBoardContextMenuItems';
|
||||||
|
import { BoardId } from 'features/gallery/store/types';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
board?: BoardDTO;
|
board?: BoardDTO;
|
||||||
board_id?: string;
|
board_id: BoardId;
|
||||||
children: ContextMenuProps<HTMLDivElement>['children'];
|
children: ContextMenuProps<HTMLDivElement>['children'];
|
||||||
setBoardToDelete?: (board?: BoardDTO) => void;
|
setBoardToDelete?: (board?: BoardDTO) => void;
|
||||||
};
|
};
|
||||||
@ -25,14 +26,17 @@ const BoardContextMenu = memo(
|
|||||||
|
|
||||||
const selector = useMemo(
|
const selector = useMemo(
|
||||||
() =>
|
() =>
|
||||||
createSelector(stateSelector, ({ gallery }) => {
|
createSelector(stateSelector, ({ gallery, system }) => {
|
||||||
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
const isAutoAdd = gallery.autoAddBoardId === board_id;
|
||||||
return { isAutoAdd };
|
const isProcessing = system.isProcessing;
|
||||||
|
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
||||||
|
return { isAutoAdd, isProcessing, autoAssignBoardOnClick };
|
||||||
}),
|
}),
|
||||||
[board_id]
|
[board_id]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { isAutoAdd } = useAppSelector(selector);
|
const { isAutoAdd, isProcessing, autoAssignBoardOnClick } =
|
||||||
|
useAppSelector(selector);
|
||||||
const boardName = useBoardName(board_id);
|
const boardName = useBoardName(board_id);
|
||||||
|
|
||||||
const handleSetAutoAdd = useCallback(() => {
|
const handleSetAutoAdd = useCallback(() => {
|
||||||
@ -59,7 +63,7 @@ const BoardContextMenu = memo(
|
|||||||
<MenuGroup title={boardName}>
|
<MenuGroup title={boardName}>
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<FaPlus />}
|
icon={<FaPlus />}
|
||||||
isDisabled={isAutoAdd}
|
isDisabled={isAutoAdd || isProcessing || autoAssignBoardOnClick}
|
||||||
onClick={handleSetAutoAdd}
|
onClick={handleSetAutoAdd}
|
||||||
>
|
>
|
||||||
Auto-add to this Board
|
Auto-add to this Board
|
||||||
|
@ -1,43 +0,0 @@
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { AddToBatchDropData } from 'app/components/ImageDnd/typesafeDnd';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { FaLayerGroup } from 'react-icons/fa';
|
|
||||||
import { useDispatch } from 'react-redux';
|
|
||||||
import GenericBoard from './GenericBoard';
|
|
||||||
|
|
||||||
const selector = createSelector(stateSelector, (state) => {
|
|
||||||
return {
|
|
||||||
count: state.gallery.batchImageNames.length,
|
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
const BatchBoard = ({ isSelected }: { isSelected: boolean }) => {
|
|
||||||
const dispatch = useDispatch();
|
|
||||||
const { count } = useAppSelector(selector);
|
|
||||||
|
|
||||||
const handleBatchBoardClick = useCallback(() => {
|
|
||||||
dispatch(boardIdSelected('batch'));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
const droppableData: AddToBatchDropData = {
|
|
||||||
id: 'batch-board',
|
|
||||||
actionType: 'ADD_TO_BATCH',
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<GenericBoard
|
|
||||||
board_id="batch"
|
|
||||||
droppableData={droppableData}
|
|
||||||
onClick={handleBatchBoardClick}
|
|
||||||
isSelected={isSelected}
|
|
||||||
icon={FaLayerGroup}
|
|
||||||
label="Batch"
|
|
||||||
badgeCount={count}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default BatchBoard;
|
|
@ -15,10 +15,9 @@ import NoBoardBoard from './NoBoardBoard';
|
|||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ boards, gallery }) => {
|
({ gallery }) => {
|
||||||
const { searchText } = boards;
|
const { selectedBoardId, boardSearchText } = gallery;
|
||||||
const { selectedBoardId } = gallery;
|
return { selectedBoardId, boardSearchText };
|
||||||
return { selectedBoardId, searchText };
|
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
@ -29,11 +28,11 @@ type Props = {
|
|||||||
|
|
||||||
const BoardsList = (props: Props) => {
|
const BoardsList = (props: Props) => {
|
||||||
const { isOpen } = props;
|
const { isOpen } = props;
|
||||||
const { selectedBoardId, searchText } = useAppSelector(selector);
|
const { selectedBoardId, boardSearchText } = useAppSelector(selector);
|
||||||
const { data: boards } = useListAllBoardsQuery();
|
const { data: boards } = useListAllBoardsQuery();
|
||||||
const filteredBoards = searchText
|
const filteredBoards = boardSearchText
|
||||||
? boards?.filter((board) =>
|
? boards?.filter((board) =>
|
||||||
board.board_name.toLowerCase().includes(searchText.toLowerCase())
|
board.board_name.toLowerCase().includes(boardSearchText.toLowerCase())
|
||||||
)
|
)
|
||||||
: boards;
|
: boards;
|
||||||
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
|
||||||
@ -75,7 +74,7 @@ const BoardsList = (props: Props) => {
|
|||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<GridItem sx={{ p: 1.5 }}>
|
<GridItem sx={{ p: 1.5 }}>
|
||||||
<NoBoardBoard isSelected={selectedBoardId === undefined} />
|
<NoBoardBoard isSelected={selectedBoardId === 'none'} />
|
||||||
</GridItem>
|
</GridItem>
|
||||||
{filteredBoards &&
|
{filteredBoards &&
|
||||||
filteredBoards.map((board) => (
|
filteredBoards.map((board) => (
|
||||||
|
@ -9,7 +9,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
|||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import { setBoardSearchText } from 'features/gallery/store/boardSlice';
|
import { boardSearchTextChanged } from 'features/gallery/store/gallerySlice';
|
||||||
import {
|
import {
|
||||||
ChangeEvent,
|
ChangeEvent,
|
||||||
KeyboardEvent,
|
KeyboardEvent,
|
||||||
@ -21,27 +21,27 @@ import {
|
|||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[stateSelector],
|
[stateSelector],
|
||||||
({ boards }) => {
|
({ gallery }) => {
|
||||||
const { searchText } = boards;
|
const { boardSearchText } = gallery;
|
||||||
return { searchText };
|
return { boardSearchText };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const BoardsSearch = () => {
|
const BoardsSearch = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { searchText } = useAppSelector(selector);
|
const { boardSearchText } = useAppSelector(selector);
|
||||||
const inputRef = useRef<HTMLInputElement>(null);
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
const handleBoardSearch = useCallback(
|
const handleBoardSearch = useCallback(
|
||||||
(searchTerm: string) => {
|
(searchTerm: string) => {
|
||||||
dispatch(setBoardSearchText(searchTerm));
|
dispatch(boardSearchTextChanged(searchTerm));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const clearBoardSearch = useCallback(() => {
|
const clearBoardSearch = useCallback(() => {
|
||||||
dispatch(setBoardSearchText(''));
|
dispatch(boardSearchTextChanged(''));
|
||||||
}, [dispatch]);
|
}, [dispatch]);
|
||||||
|
|
||||||
const handleKeydown = useCallback(
|
const handleKeydown = useCallback(
|
||||||
@ -74,11 +74,11 @@ const BoardsSearch = () => {
|
|||||||
<Input
|
<Input
|
||||||
ref={inputRef}
|
ref={inputRef}
|
||||||
placeholder="Search Boards..."
|
placeholder="Search Boards..."
|
||||||
value={searchText}
|
value={boardSearchText}
|
||||||
onKeyDown={handleKeydown}
|
onKeyDown={handleKeydown}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
/>
|
/>
|
||||||
{searchText && searchText.length && (
|
{boardSearchText && boardSearchText.length && (
|
||||||
<InputRightElement>
|
<InputRightElement>
|
||||||
<IconButton
|
<IconButton
|
||||||
onClick={clearBoardSearch}
|
onClick={clearBoardSearch}
|
||||||
|
@ -7,19 +7,27 @@ import {
|
|||||||
Icon,
|
Icon,
|
||||||
Image,
|
Image,
|
||||||
Text,
|
Text,
|
||||||
|
Tooltip,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
import {
|
||||||
|
autoAddBoardIdChanged,
|
||||||
|
boardIdSelected,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { FaUser } from 'react-icons/fa';
|
import { FaUser } from 'react-icons/fa';
|
||||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
import {
|
||||||
|
useGetBoardAssetsTotalQuery,
|
||||||
|
useGetBoardImagesTotalQuery,
|
||||||
|
useUpdateBoardMutation,
|
||||||
|
} from 'services/api/endpoints/boards';
|
||||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||||
import { BoardDTO } from 'services/api/types';
|
import { BoardDTO } from 'services/api/types';
|
||||||
import AutoAddIcon from '../AutoAddIcon';
|
import AutoAddIcon from '../AutoAddIcon';
|
||||||
@ -38,18 +46,25 @@ const GalleryBoard = memo(
|
|||||||
() =>
|
() =>
|
||||||
createSelector(
|
createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ gallery }) => {
|
({ gallery, system }) => {
|
||||||
const isSelectedForAutoAdd =
|
const isSelectedForAutoAdd =
|
||||||
board.board_id === gallery.autoAddBoardId;
|
board.board_id === gallery.autoAddBoardId;
|
||||||
|
const autoAssignBoardOnClick = gallery.autoAssignBoardOnClick;
|
||||||
|
const isProcessing = system.isProcessing;
|
||||||
|
|
||||||
return { isSelectedForAutoAdd };
|
return {
|
||||||
|
isSelectedForAutoAdd,
|
||||||
|
autoAssignBoardOnClick,
|
||||||
|
isProcessing,
|
||||||
|
};
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
),
|
),
|
||||||
[board.board_id]
|
[board.board_id]
|
||||||
);
|
);
|
||||||
|
|
||||||
const { isSelectedForAutoAdd } = useAppSelector(selector);
|
const { isSelectedForAutoAdd, autoAssignBoardOnClick, isProcessing } =
|
||||||
|
useAppSelector(selector);
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
const handleMouseOver = useCallback(() => {
|
const handleMouseOver = useCallback(() => {
|
||||||
setIsHovered(true);
|
setIsHovered(true);
|
||||||
@ -57,6 +72,18 @@ const GalleryBoard = memo(
|
|||||||
const handleMouseOut = useCallback(() => {
|
const handleMouseOut = useCallback(() => {
|
||||||
setIsHovered(false);
|
setIsHovered(false);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
|
||||||
|
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
|
||||||
|
const tooltip = useMemo(() => {
|
||||||
|
if (!imagesTotal || !assetsTotal) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return `${imagesTotal} image${
|
||||||
|
imagesTotal > 1 ? 's' : ''
|
||||||
|
}, ${assetsTotal} asset${assetsTotal > 1 ? 's' : ''}`;
|
||||||
|
}, [assetsTotal, imagesTotal]);
|
||||||
|
|
||||||
const { currentData: coverImage } = useGetImageDTOQuery(
|
const { currentData: coverImage } = useGetImageDTOQuery(
|
||||||
board.cover_image_name ?? skipToken
|
board.cover_image_name ?? skipToken
|
||||||
);
|
);
|
||||||
@ -66,15 +93,18 @@ const GalleryBoard = memo(
|
|||||||
|
|
||||||
const handleSelectBoard = useCallback(() => {
|
const handleSelectBoard = useCallback(() => {
|
||||||
dispatch(boardIdSelected(board_id));
|
dispatch(boardIdSelected(board_id));
|
||||||
}, [board_id, dispatch]);
|
if (autoAssignBoardOnClick && !isProcessing) {
|
||||||
|
dispatch(autoAddBoardIdChanged(board_id));
|
||||||
|
}
|
||||||
|
}, [board_id, autoAssignBoardOnClick, isProcessing, dispatch]);
|
||||||
|
|
||||||
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
const [updateBoard, { isLoading: isUpdateBoardLoading }] =
|
||||||
useUpdateBoardMutation();
|
useUpdateBoardMutation();
|
||||||
|
|
||||||
const droppableData: MoveBoardDropData = useMemo(
|
const droppableData: AddToBoardDropData = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
id: board_id,
|
id: board_id,
|
||||||
actionType: 'MOVE_BOARD',
|
actionType: 'ADD_TO_BOARD',
|
||||||
context: { boardId: board_id },
|
context: { boardId: board_id },
|
||||||
}),
|
}),
|
||||||
[board_id]
|
[board_id]
|
||||||
@ -135,6 +165,7 @@ const GalleryBoard = memo(
|
|||||||
setBoardToDelete={setBoardToDelete}
|
setBoardToDelete={setBoardToDelete}
|
||||||
>
|
>
|
||||||
{(ref) => (
|
{(ref) => (
|
||||||
|
<Tooltip label={tooltip} openDelay={1000} hasArrow>
|
||||||
<Flex
|
<Flex
|
||||||
ref={ref}
|
ref={ref}
|
||||||
onClick={handleSelectBoard}
|
onClick={handleSelectBoard}
|
||||||
@ -265,6 +296,7 @@ const GalleryBoard = memo(
|
|||||||
dropLabel={<Text fontSize="md">Move</Text>}
|
dropLabel={<Text fontSize="md">Move</Text>}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
</Tooltip>
|
||||||
)}
|
)}
|
||||||
</BoardContextMenu>
|
</BoardContextMenu>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -1,50 +1,60 @@
|
|||||||
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
import { Box, Flex, Image, Text } from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { MoveBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd';
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import InvokeAILogoImage from 'assets/images/logo.png';
|
import InvokeAILogoImage from 'assets/images/logo.png';
|
||||||
import IAIDroppable from 'common/components/IAIDroppable';
|
import IAIDroppable from 'common/components/IAIDroppable';
|
||||||
import SelectionOverlay from 'common/components/SelectionOverlay';
|
import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||||
import { boardIdSelected } from 'features/gallery/store/gallerySlice';
|
import {
|
||||||
|
boardIdSelected,
|
||||||
|
autoAddBoardIdChanged,
|
||||||
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
import { memo, useCallback, useMemo, useState } from 'react';
|
||||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||||
import AutoAddIcon from '../AutoAddIcon';
|
import AutoAddIcon from '../AutoAddIcon';
|
||||||
import BoardContextMenu from '../BoardContextMenu';
|
import BoardContextMenu from '../BoardContextMenu';
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
stateSelector,
|
stateSelector,
|
||||||
({ gallery }) => {
|
({ gallery, system }) => {
|
||||||
const { autoAddBoardId } = gallery;
|
const { autoAddBoardId, autoAssignBoardOnClick } = gallery;
|
||||||
return { autoAddBoardId };
|
const { isProcessing } = system;
|
||||||
|
return { autoAddBoardId, autoAssignBoardOnClick, isProcessing };
|
||||||
},
|
},
|
||||||
defaultSelectorOptions
|
defaultSelectorOptions
|
||||||
);
|
);
|
||||||
|
|
||||||
const NoBoardBoard = memo(({ isSelected }: Props) => {
|
const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { autoAddBoardId } = useAppSelector(selector);
|
const { autoAddBoardId, autoAssignBoardOnClick, isProcessing } =
|
||||||
const boardName = useBoardName(undefined);
|
useAppSelector(selector);
|
||||||
|
const boardName = useBoardName('none');
|
||||||
const handleSelectBoard = useCallback(() => {
|
const handleSelectBoard = useCallback(() => {
|
||||||
dispatch(boardIdSelected(undefined));
|
dispatch(boardIdSelected('none'));
|
||||||
}, [dispatch]);
|
if (autoAssignBoardOnClick && !isProcessing) {
|
||||||
|
dispatch(autoAddBoardIdChanged('none'));
|
||||||
|
}
|
||||||
|
}, [dispatch, autoAssignBoardOnClick, isProcessing]);
|
||||||
const [isHovered, setIsHovered] = useState(false);
|
const [isHovered, setIsHovered] = useState(false);
|
||||||
|
|
||||||
const handleMouseOver = useCallback(() => {
|
const handleMouseOver = useCallback(() => {
|
||||||
setIsHovered(true);
|
setIsHovered(true);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const handleMouseOut = useCallback(() => {
|
const handleMouseOut = useCallback(() => {
|
||||||
setIsHovered(false);
|
setIsHovered(false);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const droppableData: MoveBoardDropData = useMemo(
|
const droppableData: RemoveFromBoardDropData = useMemo(
|
||||||
() => ({
|
() => ({
|
||||||
id: 'no_board',
|
id: 'no_board',
|
||||||
actionType: 'MOVE_BOARD',
|
actionType: 'REMOVE_FROM_BOARD',
|
||||||
context: { boardId: undefined },
|
|
||||||
}),
|
}),
|
||||||
[]
|
[]
|
||||||
);
|
);
|
||||||
@ -64,7 +74,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
h: 'full',
|
h: 'full',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<BoardContextMenu>
|
<BoardContextMenu board_id="none">
|
||||||
{(ref) => (
|
{(ref) => (
|
||||||
<Flex
|
<Flex
|
||||||
ref={ref}
|
ref={ref}
|
||||||
@ -91,17 +101,6 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
alignItems: 'center',
|
alignItems: 'center',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
{/* <Icon
|
|
||||||
boxSize={12}
|
|
||||||
as={FaBucket}
|
|
||||||
sx={{
|
|
||||||
opacity: 0.7,
|
|
||||||
color: 'base.500',
|
|
||||||
_dark: {
|
|
||||||
color: 'base.500',
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
/> */}
|
|
||||||
<Image
|
<Image
|
||||||
src={InvokeAILogoImage}
|
src={InvokeAILogoImage}
|
||||||
alt="invoke-ai-logo"
|
alt="invoke-ai-logo"
|
||||||
@ -117,19 +116,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
{/* <Flex
|
{autoAddBoardId === 'none' && <AutoAddIcon />}
|
||||||
sx={{
|
|
||||||
position: 'absolute',
|
|
||||||
insetInlineEnd: 0,
|
|
||||||
top: 0,
|
|
||||||
p: 1,
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Badge variant="solid" sx={BASE_BADGE_STYLES}>
|
|
||||||
{totalImages}/{totalAssets}
|
|
||||||
</Badge>
|
|
||||||
</Flex> */}
|
|
||||||
{!autoAddBoardId && <AutoAddIcon />}
|
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
position: 'absolute',
|
position: 'absolute',
|
||||||
|
@ -11,20 +11,20 @@ import {
|
|||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||||
import { ImageUsage } from 'app/contexts/AddImageToBoardContext';
|
|
||||||
import { stateSelector } from 'app/store/store';
|
import { stateSelector } from 'app/store/store';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import ImageUsageMessage from 'features/imageDeletion/components/ImageUsageMessage';
|
import ImageUsageMessage from 'features/deleteImageModal/components/ImageUsageMessage';
|
||||||
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
|
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
|
||||||
|
import { ImageUsage } from 'features/deleteImageModal/store/types';
|
||||||
import { some } from 'lodash-es';
|
import { some } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useListAllImageNamesForBoardQuery } from 'services/api/endpoints/boards';
|
||||||
import {
|
import {
|
||||||
useDeleteBoardAndImagesMutation,
|
useDeleteBoardAndImagesMutation,
|
||||||
useDeleteBoardMutation,
|
useDeleteBoardMutation,
|
||||||
useListAllImageNamesForBoardQuery,
|
} from 'services/api/endpoints/images';
|
||||||
} from 'services/api/endpoints/boards';
|
|
||||||
import { BoardDTO } from 'services/api/types';
|
import { BoardDTO } from 'services/api/types';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
@ -32,7 +32,7 @@ type Props = {
|
|||||||
setBoardToDelete: (board?: BoardDTO) => void;
|
setBoardToDelete: (board?: BoardDTO) => void;
|
||||||
};
|
};
|
||||||
|
|
||||||
const DeleteImageModal = (props: Props) => {
|
const DeleteBoardModal = (props: Props) => {
|
||||||
const { boardToDelete, setBoardToDelete } = props;
|
const { boardToDelete, setBoardToDelete } = props;
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const canRestoreDeletedImagesFromBin = useAppSelector(
|
const canRestoreDeletedImagesFromBin = useAppSelector(
|
||||||
@ -49,13 +49,10 @@ const DeleteImageModal = (props: Props) => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const imageUsageSummary: ImageUsage = {
|
const imageUsageSummary: ImageUsage = {
|
||||||
isInitialImage: some(allImageUsage, (usage) => usage.isInitialImage),
|
isInitialImage: some(allImageUsage, (i) => i.isInitialImage),
|
||||||
isCanvasImage: some(allImageUsage, (usage) => usage.isCanvasImage),
|
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
|
||||||
isNodesImage: some(allImageUsage, (usage) => usage.isNodesImage),
|
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
|
||||||
isControlNetImage: some(
|
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
|
||||||
allImageUsage,
|
|
||||||
(usage) => usage.isControlNetImage
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
return { imageUsageSummary };
|
return { imageUsageSummary };
|
||||||
}),
|
}),
|
||||||
@ -176,4 +173,4 @@ const DeleteImageModal = (props: Props) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(DeleteImageModal);
|
export default memo(DeleteBoardModal);
|
||||||
|
@ -1,93 +0,0 @@
|
|||||||
import {
|
|
||||||
AlertDialog,
|
|
||||||
AlertDialogBody,
|
|
||||||
AlertDialogContent,
|
|
||||||
AlertDialogFooter,
|
|
||||||
AlertDialogHeader,
|
|
||||||
AlertDialogOverlay,
|
|
||||||
Box,
|
|
||||||
Flex,
|
|
||||||
Spinner,
|
|
||||||
Text,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import IAIButton from 'common/components/IAIButton';
|
|
||||||
|
|
||||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
|
||||||
import { memo, useContext, useRef, useState } from 'react';
|
|
||||||
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
|
|
||||||
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
|
|
||||||
|
|
||||||
const UpdateImageBoardModal = () => {
|
|
||||||
// const boards = useSelector(selectBoardsAll);
|
|
||||||
const { data: boards, isFetching } = useListAllBoardsQuery();
|
|
||||||
const { isOpen, onClose, handleAddToBoard, image } = useContext(
|
|
||||||
AddImageToBoardContext
|
|
||||||
);
|
|
||||||
const [selectedBoard, setSelectedBoard] = useState<string | null>();
|
|
||||||
|
|
||||||
const cancelRef = useRef<HTMLButtonElement>(null);
|
|
||||||
|
|
||||||
const currentBoard = boards?.find(
|
|
||||||
(board) => board.board_id === image?.board_id
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<AlertDialog
|
|
||||||
isOpen={isOpen}
|
|
||||||
leastDestructiveRef={cancelRef}
|
|
||||||
onClose={onClose}
|
|
||||||
isCentered
|
|
||||||
>
|
|
||||||
<AlertDialogOverlay>
|
|
||||||
<AlertDialogContent>
|
|
||||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
|
||||||
{currentBoard ? 'Move Image to Board' : 'Add Image to Board'}
|
|
||||||
</AlertDialogHeader>
|
|
||||||
|
|
||||||
<AlertDialogBody>
|
|
||||||
<Box>
|
|
||||||
<Flex direction="column" gap={3}>
|
|
||||||
{currentBoard && (
|
|
||||||
<Text>
|
|
||||||
Moving this image from{' '}
|
|
||||||
<strong>{currentBoard.board_name}</strong> to
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
{isFetching ? (
|
|
||||||
<Spinner />
|
|
||||||
) : (
|
|
||||||
<IAIMantineSearchableSelect
|
|
||||||
placeholder="Select Board"
|
|
||||||
onChange={(v) => setSelectedBoard(v)}
|
|
||||||
value={selectedBoard}
|
|
||||||
data={(boards ?? []).map((board) => ({
|
|
||||||
label: board.board_name,
|
|
||||||
value: board.board_id,
|
|
||||||
}))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</Box>
|
|
||||||
</AlertDialogBody>
|
|
||||||
<AlertDialogFooter>
|
|
||||||
<IAIButton onClick={onClose}>Cancel</IAIButton>
|
|
||||||
<IAIButton
|
|
||||||
isDisabled={!selectedBoard}
|
|
||||||
colorScheme="accent"
|
|
||||||
onClick={() => {
|
|
||||||
if (selectedBoard) {
|
|
||||||
handleAddToBoard(selectedBoard);
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
ml={3}
|
|
||||||
>
|
|
||||||
{currentBoard ? 'Move' : 'Add'}
|
|
||||||
</IAIButton>
|
|
||||||
</AlertDialogFooter>
|
|
||||||
</AlertDialogContent>
|
|
||||||
</AlertDialogOverlay>
|
|
||||||
</AlertDialog>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(UpdateImageBoardModal);
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user