mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Merge branch 'main' into feat/lora_model_patch
This commit is contained in commit ac46b129bf.

32  .github/workflows/test-invoke-pip-skip.yml (vendored)
@@ -1,10 +1,16 @@
 name: Test invoke.py pip
+
+# This is a dummy stand-in for the actual tests
+# we don't need to run python tests on non-Python changes
+# But PRs require passing tests to be mergeable
+
 on:
   pull_request:
     paths:
       - '**'
       - '!pyproject.toml'
       - '!invokeai/**'
+      - '!tests/**'
       - 'invokeai/frontend/web/**'
   merge_group:
   workflow_dispatch:
@@ -19,48 +25,26 @@ jobs:
     strategy:
       matrix:
         python-version:
-          # - '3.9'
           - '3.10'
         pytorch:
-          # - linux-cuda-11_6
           - linux-cuda-11_7
           - linux-rocm-5_2
           - linux-cpu
           - macos-default
           - windows-cpu
-          # - windows-cuda-11_6
-          # - windows-cuda-11_7
         include:
-          # - pytorch: linux-cuda-11_6
-          #   os: ubuntu-22.04
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu116'
-          #   github-env: $GITHUB_ENV
          - pytorch: linux-cuda-11_7
            os: ubuntu-22.04
-            github-env: $GITHUB_ENV
          - pytorch: linux-rocm-5_2
            os: ubuntu-22.04
-            extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
-            github-env: $GITHUB_ENV
          - pytorch: linux-cpu
            os: ubuntu-22.04
-            extra-index-url: 'https://download.pytorch.org/whl/cpu'
-            github-env: $GITHUB_ENV
          - pytorch: macos-default
            os: macOS-12
-            github-env: $GITHUB_ENV
          - pytorch: windows-cpu
            os: windows-2022
-            github-env: $env:GITHUB_ENV
-          # - pytorch: windows-cuda-11_6
-          #   os: windows-2022
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu116'
-          #   github-env: $env:GITHUB_ENV
-          # - pytorch: windows-cuda-11_7
-          #   os: windows-2022
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu117'
-          #   github-env: $env:GITHUB_ENV
     name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
     runs-on: ${{ matrix.os }}
     steps:
-      - run: 'echo "No build required"'
+      - name: skip
+        run: echo "no build required"
84  .github/workflows/test-invoke-pip.yml (vendored)
@@ -11,6 +11,7 @@ on:
     paths:
       - 'pyproject.toml'
       - 'invokeai/**'
+      - 'tests/**'
       - '!invokeai/frontend/web/**'
     types:
       - 'ready_for_review'
@@ -32,19 +33,12 @@ jobs:
           # - '3.9'
           - '3.10'
         pytorch:
-          # - linux-cuda-11_6
           - linux-cuda-11_7
           - linux-rocm-5_2
           - linux-cpu
           - macos-default
           - windows-cpu
-          # - windows-cuda-11_6
-          # - windows-cuda-11_7
         include:
-          # - pytorch: linux-cuda-11_6
-          #   os: ubuntu-22.04
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu116'
-          #   github-env: $GITHUB_ENV
          - pytorch: linux-cuda-11_7
            os: ubuntu-22.04
            github-env: $GITHUB_ENV
@@ -62,14 +56,6 @@ jobs:
          - pytorch: windows-cpu
            os: windows-2022
            github-env: $env:GITHUB_ENV
-          # - pytorch: windows-cuda-11_6
-          #   os: windows-2022
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu116'
-          #   github-env: $env:GITHUB_ENV
-          # - pytorch: windows-cuda-11_7
-          #   os: windows-2022
-          #   extra-index-url: 'https://download.pytorch.org/whl/cu117'
-          #   github-env: $env:GITHUB_ENV
     name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
     runs-on: ${{ matrix.os }}
     env:
@@ -100,40 +86,38 @@ jobs:
         id: run-pytest
         run: pytest

-      - name: run invokeai-configure
-        id: run-preload-models
-        env:
-          HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
-        run: >
-          invokeai-configure
-          --yes
-          --default_only
-          --full-precision
-          # can't use fp16 weights without a GPU
+      # - name: run invokeai-configure
+      #   env:
+      #     HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
+      #   run: >
+      #     invokeai-configure
+      #     --yes
+      #     --default_only
+      #     --full-precision
+      #     # can't use fp16 weights without a GPU

-      - name: run invokeai
-        id: run-invokeai
-        env:
-          # Set offline mode to make sure configure preloaded successfully.
-          HF_HUB_OFFLINE: 1
-          HF_DATASETS_OFFLINE: 1
-          TRANSFORMERS_OFFLINE: 1
-          INVOKEAI_OUTDIR: ${{ github.workspace }}/results
-        run: >
-          invokeai
-          --no-patchmatch
-          --no-nsfw_checker
-          --precision=float32
-          --always_use_cpu
-          --use_memory_db
-          --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
-          --from_file ${{ env.TEST_PROMPTS }}
+      # - name: run invokeai
+      #   id: run-invokeai
+      #   env:
+      #     # Set offline mode to make sure configure preloaded successfully.
+      #     HF_HUB_OFFLINE: 1
+      #     HF_DATASETS_OFFLINE: 1
+      #     TRANSFORMERS_OFFLINE: 1
+      #     INVOKEAI_OUTDIR: ${{ github.workspace }}/results
+      #   run: >
+      #     invokeai
+      #     --no-patchmatch
+      #     --no-nsfw_checker
+      #     --precision=float32
+      #     --always_use_cpu
+      #     --use_memory_db
+      #     --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
+      #     --from_file ${{ env.TEST_PROMPTS }}

-      - name: Archive results
-        id: archive-results
-        env:
-          INVOKEAI_OUTDIR: ${{ github.workspace }}/results
-        uses: actions/upload-artifact@v3
-        with:
-          name: results
-          path: ${{ env.INVOKEAI_OUTDIR }}
+      # - name: Archive results
+      #   env:
+      #     INVOKEAI_OUTDIR: ${{ github.workspace }}/results
+      #   uses: actions/upload-artifact@v3
+      #   with:
+      #     name: results
+      #     path: ${{ env.INVOKEAI_OUTDIR }}
@@ -87,18 +87,18 @@ Prior to installing PyPatchMatch, you need to take the following steps:
    sudo pacman -S --needed base-devel
    ```

-2. Install `opencv`:
+2. Install `opencv` and `blas`:

    ```sh
-   sudo pacman -S opencv
+   sudo pacman -S opencv blas
    ```

    or for CUDA support

    ```sh
-   sudo pacman -S opencv-cuda
+   sudo pacman -S opencv-cuda blas
    ```

 3. Fix the naming of the `opencv` package configuration file:

    ```sh
@@ -38,6 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
 echo.
 echo See %INSTRUCTIONS% for more details.
 echo.
+echo "For the best user experience we suggest enlarging or maximizing this window now."
 pause

 @rem ---------------------------- check Python version ---------------
@@ -26,6 +26,7 @@ done
 if [ -z "$PYTHON" ]; then
     echo "A suitable Python interpreter could not be found"
     echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
+    echo "For the best user experience we suggest enlarging or maximizing this window now."
     read -p "Press any key to exit"
     exit -1
 fi
@@ -293,6 +293,8 @@ def introduction() -> None:
             "3. Create initial configuration files.",
             "",
             "[i]At any point you may interrupt this program and resume later.",
+            "",
+            "[b]For the best user experience, please enlarge or maximize this window",
         ),
     )
 )
@@ -1,10 +1,11 @@
-# InvokeAI nodes for ControlNet image preprocessors
+# Invocations for ControlNet image preprocessors
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
-from builtins import float
+from builtins import float, bool

+import cv2
 import numpy as np
-from typing import Literal, Optional, Union, List
+from typing import Literal, Optional, Union, List, Dict
 from PIL import Image, ImageFilter, ImageOps
 from pydantic import BaseModel, Field, validator

@@ -29,8 +30,13 @@ from controlnet_aux import (
     ContentShuffleDetector,
     ZoeDetector,
     MediapipeFaceDetector,
+    SamDetector,
+    LeresDetector,
 )
+
+from controlnet_aux.util import HWC3, ade_palette

 from .image import ImageOutput, PILInvocationConfig

 CONTROLNET_DEFAULT_MODELS = [
@@ -94,6 +100,10 @@ CONTROLNET_DEFAULT_MODELS = [
 ]

 CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
+CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
+# crop and fill options not ready yet
+# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
+

 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
@@ -104,6 +114,9 @@ class ControlField(BaseModel):
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
+
     @validator("control_weight")
     def abs_le_one(cls, v):
         """validate that all abs(values) are <=1"""
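Note: the `CONTROLNET_MODE_VALUES = Literal[tuple([...])]` pattern added above turns a plain list of strings into a closed set of choices that pydantic validates at assignment time. A minimal standalone sketch of the same technique (the `Mode`/`Example` names are illustrative only, not part of the InvokeAI code):

```python
from typing import Literal
from pydantic import BaseModel, Field, ValidationError

# Subscripting Literal with a tuple expands to Literal["a", "b", ...] at
# runtime, giving an enum-like field without declaring an Enum class.
MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]

class Example(BaseModel):
    control_mode: MODE_VALUES = Field(default="balanced")

print(Example().control_mode)               # -> balanced
print(Example(control_mode="more_prompt"))  # accepted
try:
    Example(control_mode="bogus")           # rejected by pydantic
except ValidationError as err:
    print("rejected:", err.errors()[0]["msg"])
```

Static type checkers do not accept a dynamically built Literal, but it works at runtime with the pydantic v1 API used here.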
@ -144,11 +157,11 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
|
||||||
description="control model used")
|
description="control model used")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
|
||||||
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
|
|
||||||
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
begin_step_percent: float = Field(default=0, ge=0, le=1,
|
||||||
description="When the ControlNet is first applied (% of total steps)")
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(default=1, ge=0, le=1,
|
end_step_percent: float = Field(default=1, ge=0, le=1,
|
||||||
description="When the ControlNet is last applied (% of total steps)")
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -166,7 +179,6 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||||
|
|
||||||
return ControlOutput(
|
return ControlOutput(
|
||||||
control=ControlField(
|
control=ControlField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
@ -174,10 +186,11 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
control_weight=self.control_weight,
|
control_weight=self.control_weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
|
control_mode=self.control_mode,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: move image processors to separate file (image_analysis.py
|
|
||||||
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Base class for invocations that preprocess images for ControlNet"""
|
"""Base class for invocations that preprocess images for ControlNet"""
|
||||||
|
|
||||||
@ -449,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
|
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||||
|
# so convert to RGB if needed
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
image = image.convert('RGB')
|
||||||
mediapipe_face_processor = MediapipeFaceDetector()
|
mediapipe_face_processor = MediapipeFaceDetector()
|
||||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies leres processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||||
|
# Inputs
|
||||||
|
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
|
||||||
|
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
|
||||||
|
boost: bool = Field(default=False, description="Whether to use boost mode")
|
||||||
|
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
|
||||||
|
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
|
processed_image = leres_processor(image,
|
||||||
|
thr_a=self.thr_a,
|
||||||
|
thr_b=self.thr_b,
|
||||||
|
boost=self.boost,
|
||||||
|
detect_resolution=self.detect_resolution,
|
||||||
|
image_resolution=self.image_resolution)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||||
|
# Inputs
|
||||||
|
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
|
||||||
|
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||||
|
def tile_resample(self,
|
||||||
|
np_img: np.ndarray,
|
||||||
|
res=512, # never used?
|
||||||
|
down_sampling_rate=1.0,
|
||||||
|
):
|
||||||
|
np_img = HWC3(np_img)
|
||||||
|
if down_sampling_rate < 1.1:
|
||||||
|
return np_img
|
||||||
|
H, W, C = np_img.shape
|
||||||
|
H = int(float(H) / float(down_sampling_rate))
|
||||||
|
W = int(float(W) / float(down_sampling_rate))
|
||||||
|
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||||
|
return np_img
|
||||||
|
|
||||||
|
def run_processor(self, img):
|
||||||
|
np_img = np.array(img, dtype=np.uint8)
|
||||||
|
processed_np_image = self.tile_resample(np_img,
|
||||||
|
#res=self.tile_size,
|
||||||
|
down_sampling_rate=self.down_sampling_rate
|
||||||
|
)
|
||||||
|
processed_image = Image.fromarray(processed_np_image)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
"""Applies segment anything processing to image"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def run_processor(self, image):
|
||||||
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
|
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
|
np_img = np.array(image, dtype=np.uint8)
|
||||||
|
processed_image = segment_anything_processor(np_img)
|
||||||
|
return processed_image
|
||||||
|
|
||||||
|
class SamDetectorReproducibleColors(SamDetector):
|
||||||
|
|
||||||
|
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||||
|
# base class show_anns() method randomizes colors,
|
||||||
|
# which seems to also lead to non-reproducible image generation
|
||||||
|
# so using ADE20k color palette instead
|
||||||
|
def show_anns(self, anns: List[Dict]):
|
||||||
|
if len(anns) == 0:
|
||||||
|
return
|
||||||
|
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||||
|
h, w = anns[0]['segmentation'].shape
|
||||||
|
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||||
|
palette = ade_palette()
|
||||||
|
for i, ann in enumerate(sorted_anns):
|
||||||
|
m = ann['segmentation']
|
||||||
|
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||||
|
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||||
|
ann_color = palette[i % len(palette)]
|
||||||
|
img[:, :] = ann_color
|
||||||
|
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
||||||
|
return np.array(final_img, dtype=np.uint8)
|
||||||
|
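Note: the `show_anns` override above replaces SAM's randomized mask colors with a fixed palette indexed by mask rank (largest mask first), so the same segmentation always renders to the same image. A standalone sketch of that idea follows; the three-color `PALETTE` here is an illustrative stand-in for the real `ade_palette()` from controlnet_aux:

```python
import numpy as np
from PIL import Image

# Stand-in palette; the real code uses controlnet_aux.util.ade_palette().
PALETTE = [(120, 120, 120), (180, 120, 120), (6, 230, 230)]

def render_masks(masks: list) -> np.ndarray:
    """Paint boolean masks largest-first with deterministic colors."""
    masks = sorted(masks, key=lambda m: m.sum(), reverse=True)
    h, w = masks[0].shape
    out = Image.new("RGB", (w, h))
    for i, m in enumerate(masks):
        color = PALETTE[i % len(PALETTE)]  # modulo guards against palette overflow
        layer = Image.new("RGB", (w, h), color)
        out.paste(layer, (0, 0), Image.fromarray(np.uint8(m) * 255))
    return np.array(out)

# Two runs over the same masks yield byte-identical output.
rng = np.random.default_rng(0)
masks = [rng.random((4, 4)) > 0.5 for _ in range(3)]
assert np.array_equal(render_masks(masks), render_masks(masks))
```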
@@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
     PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import torch_dtype
 from ...backend.model_management.lora import ModelPatcher
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
@@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
         height=latents.size()[2] * 8,
     )

-class NoiseOutput(BaseInvocationOutput):
-    """Invocation noise output"""
-    #fmt: off
-    type: Literal["noise_output"] = "noise_output"
-
-    # Inputs
-    noise: LatentsField = Field(default=None, description="The output noise")
-    width: int = Field(description="The width of the noise in pixels")
-    height: int = Field(description="The height of the noise in pixels")
-    #fmt: on
-
-def build_noise_output(latents_name: str, latents: torch.Tensor):
-    return NoiseOutput(
-        noise=LatentsField(latents_name=latents_name),
-        width=latents.size()[3] * 8,
-        height=latents.size()[2] * 8,
-    )
-
 SAMPLER_NAME_VALUES = Literal[
     tuple(list(SCHEDULER_MAP.keys()))
 ]


 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
@@ -105,62 +86,6 @@ def get_scheduler(
     return scheduler


-def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
-    # limit noise to only the diffusion image channels, not the mask channels
-    input_channels = min(latent_channels, 4)
-    use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
-    generator = torch.Generator(device=use_device).manual_seed(seed)
-    x = torch.randn(
-        [
-            1,
-            input_channels,
-            height // downsampling_factor,
-            width // downsampling_factor,
-        ],
-        dtype=torch_dtype(device),
-        device=use_device,
-        generator=generator,
-    ).to(device)
-    # if self.perlin > 0.0:
-    #     perlin_noise = self.get_perlin_noise(
-    #         width // self.downsampling_factor, height // self.downsampling_factor
-    #     )
-    #     x = (1 - self.perlin) * x + self.perlin * perlin_noise
-    return x
-
-class NoiseInvocation(BaseInvocation):
-    """Generates latent noise."""
-
-    type: Literal["noise"] = "noise"
-
-    # Inputs
-    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
-    width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
-    height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["latents", "noise"],
-            },
-        }
-
-    @validator("seed", pre=True)
-    def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
-
-    def invoke(self, context: InvocationContext) -> NoiseOutput:
-        device = torch.device(choose_torch_device())
-        noise = get_noise(self.width, self.height, device, self.seed)
-
-        name = f'{context.graph_execution_state_id}__{self.id}'
-        context.services.latents.save(name, noise)
-        return build_noise_output(latents_name=name, latents=noise)
-
-
 # Text to image
 class TextToLatentsInvocation(BaseInvocation):
     """Generates latents from conditionings."""
@@ -287,19 +212,14 @@ class TextToLatentsInvocation(BaseInvocation):
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
         if control_input is None:
-            # print("control input is None")
             control_list = None
         elif isinstance(control_input, list) and len(control_input) == 0:
-            # print("control input is empty list")
             control_list = None
         elif isinstance(control_input, ControlField):
-            # print("control input is ControlField")
             control_list = [control_input]
         elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
-            # print("control input is list[ControlField]")
             control_list = control_input
         else:
-            # print("input control is unrecognized:", type(self.control))
             control_list = None
         if (control_list is None):
             control_data = None
@@ -341,12 +261,15 @@ class TextToLatentsInvocation(BaseInvocation):
                 # num_images_per_prompt=num_images_per_prompt,
                 device=control_model.device,
                 dtype=control_model.dtype,
+                control_mode=control_info.control_mode,
             )
             control_item = ControlNetData(model=control_model,
                                           image_tensor=control_image,
                                           weight=control_info.control_weight,
                                           begin_step_percent=control_info.begin_step_percent,
-                                          end_step_percent=control_info.end_step_percent)
+                                          end_step_percent=control_info.end_step_percent,
+                                          control_mode=control_info.control_mode,
+                                          )
             control_data.append(control_item)
         # MultiControlNetModel has been refactored out, just need list[ControlNetData]
         return control_data
134  invokeai/app/invocations/noise.py (new file)
@@ -0,0 +1,134 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
+
+import math
+from typing import Literal
+
+from pydantic import Field, validator
+import torch
+from invokeai.app.invocations.latent import LatentsField
+
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from ...backend.util.devices import choose_torch_device, torch_dtype
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationConfig,
+    InvocationContext,
+)
+
+"""
+Utilities
+"""
+
+
+def get_noise(
+    width: int,
+    height: int,
+    device: torch.device,
+    seed: int = 0,
+    latent_channels: int = 4,
+    downsampling_factor: int = 8,
+    use_cpu: bool = True,
+    perlin: float = 0.0,
+):
+    """Generate noise for a given image size."""
+    noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
+
+    # limit noise to only the diffusion image channels, not the mask channels
+    input_channels = min(latent_channels, 4)
+    generator = torch.Generator(device=noise_device_type).manual_seed(seed)
+
+    noise_tensor = torch.randn(
+        [
+            1,
+            input_channels,
+            height // downsampling_factor,
+            width // downsampling_factor,
+        ],
+        dtype=torch_dtype(device),
+        device=noise_device_type,
+        generator=generator,
+    ).to(device)
+
+    return noise_tensor
+
+
+"""
+Nodes
+"""
+
+
+class NoiseOutput(BaseInvocationOutput):
+    """Invocation noise output"""
+
+    # fmt: off
+    type: Literal["noise_output"] = "noise_output"
+
+    # Inputs
+    noise: LatentsField = Field(default=None, description="The output noise")
+    width: int = Field(description="The width of the noise in pixels")
+    height: int = Field(description="The height of the noise in pixels")
+    # fmt: on
+
+
+def build_noise_output(latents_name: str, latents: torch.Tensor):
+    return NoiseOutput(
+        noise=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )
+
+
+class NoiseInvocation(BaseInvocation):
+    """Generates latent noise."""
+
+    type: Literal["noise"] = "noise"
+
+    # Inputs
+    seed: int = Field(
+        ge=0,
+        le=SEED_MAX,
+        description="The seed to use",
+        default_factory=get_random_seed,
+    )
+    width: int = Field(
+        default=512,
+        multiple_of=8,
+        gt=0,
+        description="The width of the resulting noise",
+    )
+    height: int = Field(
+        default=512,
+        multiple_of=8,
+        gt=0,
+        description="The height of the resulting noise",
+    )
+    use_cpu: bool = Field(
+        default=True,
+        description="Use CPU for noise generation (for reproducible results across platforms)",
+    )
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "noise"],
+            },
+        }
+
+    @validator("seed", pre=True)
+    def modulo_seed(cls, v):
+        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
+        return v % SEED_MAX
+
+    def invoke(self, context: InvocationContext) -> NoiseOutput:
+        noise = get_noise(
+            width=self.width,
+            height=self.height,
+            device=choose_torch_device(),
+            seed=self.seed,
+            use_cpu=self.use_cpu,
+        )
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.save(name, noise)
+        return build_noise_output(latents_name=name, latents=noise)
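Note: the new `get_noise` defaults to sampling on the CPU (`use_cpu=True`) so that a given seed produces the same latents whether the target device is CUDA, MPS, or CPU. A minimal standalone sketch of why seeding a generator on one fixed device and then moving the tensor gives cross-device reproducibility (plain torch, no InvokeAI dependencies):

```python
import torch

def make_noise(seed: int, shape=(1, 4, 64, 64), target="cpu") -> torch.Tensor:
    # Sample on CPU with a seeded generator, then move to the target device.
    # Sampling directly on each device would give device-dependent values.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, generator=generator, device="cpu").to(target)

a = make_noise(seed=42)
b = make_noise(seed=42)
assert torch.equal(a, b)  # same seed, identical noise, regardless of target
```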
@@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
         postlist = list(num_poststeps * [self.post_end_value])

         if log_diagnostics:
-            logger = InvokeAILogger.getLogger(name="StepParamEasing")
-            logger.debug("start_step: " + str(start_step))
-            logger.debug("end_step: " + str(end_step))
-            logger.debug("num_easing_steps: " + str(num_easing_steps))
-            logger.debug("num_presteps: " + str(num_presteps))
-            logger.debug("num_poststeps: " + str(num_poststeps))
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("postlist size: " + str(len(postlist)))
-            logger.debug("prelist: " + str(prelist))
-            logger.debug("postlist: " + str(postlist))
+            context.services.logger.debug("start_step: " + str(start_step))
+            context.services.logger.debug("end_step: " + str(end_step))
+            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
+            context.services.logger.debug("num_presteps: " + str(num_presteps))
+            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist: " + str(prelist))
+            context.services.logger.debug("postlist: " + str(postlist))

         easing_class = EASING_FUNCTIONS_MAP[self.easing]
         if log_diagnostics:
-            logger.debug("easing class: " + str(easing_class))
+            context.services.logger.debug("easing class: " + str(easing_class))
         easing_list = list()
         if self.mirror:  # "expected" mirroring
             # if number of steps is even, squeeze duration down to (number_of_steps)/2
@@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
             # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always

             base_easing_duration = int(np.ceil(num_easing_steps/2.0))
-            if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration))
+            if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
             even_num_steps = (num_easing_steps % 2 == 0)  # even number of steps
             easing_function = easing_class(start=self.start_value,
                                            end=self.end_value,
@@ -166,14 +165,14 @@ class StepParamEasingInvocation(BaseInvocation):
                 easing_val = easing_function.ease(step_index)
                 base_easing_vals.append(easing_val)
                 if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
             if even_num_steps:
                 mirror_easing_vals = list(reversed(base_easing_vals))
             else:
                 mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
             if log_diagnostics:
-                logger.debug("base easing vals: " + str(base_easing_vals))
-                logger.debug("mirror easing vals: " + str(mirror_easing_vals))
+                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
+                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
             easing_list = base_easing_vals + mirror_easing_vals

         # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@@ -206,12 +205,12 @@ class StepParamEasingInvocation(BaseInvocation):
                 step_val = easing_function.ease(step_index)
                 easing_list.append(step_val)
                 if log_diagnostics:
-                    logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
+                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))

         if log_diagnostics:
-            logger.debug("prelist size: " + str(len(prelist)))
-            logger.debug("easing_list size: " + str(len(easing_list)))
-            logger.debug("postlist size: " + str(len(postlist)))
+            context.services.logger.debug("prelist size: " + str(len(prelist)))
+            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
+            context.services.logger.debug("postlist size: " + str(len(postlist)))

         param_list = prelist + easing_list + postlist
@@ -374,8 +374,10 @@ setting environment variables INVOKEAI_<setting>.
     tiled_decode    : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')

     root            : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
-    autoimport_dir  : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
-    autoconvert_dir : Path = Field(default=None, description='Deprecated configuration option.', category='Paths')
+    autoimport_dir  : Path = Field(default='autoimport/main', description='Path to a directory of models files to be imported on startup.', category='Paths')
+    lora_dir        : Path = Field(default='autoimport/lora', description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
+    embedding_dir   : Path = Field(default='autoimport/embedding', description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
+    controlnet_dir  : Path = Field(default='autoimport/controlnet', description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
     conf_path       : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
     models_dir      : Path = Field(default='models', description='Path to the models directory', category='Paths')
     legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')
@@ -1,4 +1,5 @@
-from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
+from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
+from ..invocations.noise import NoiseInvocation
 from ..invocations.compel import CompelInvocation
 from ..invocations.params import ParamIntInvocation
 from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
@@ -7,8 +7,6 @@
 # Coauthor: Kevin Turner http://github.com/keturn
 #
 import sys
-print("Loading Python libraries...\n",file=sys.stderr)
-
 import argparse
 import io
 import os
@@ -442,6 +440,26 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
             scroll_exit=True,
         )
         self.nextrely += 1
+        self.add_widget_intelligent(
+            npyscreen.FixedText,
+            value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
+            editable=False,
+            color="CONTROL",
+        )
+        self.autoimport_dirs = {}
+        for description, config_name, path in autoimport_paths(old_opts):
+            self.autoimport_dirs[config_name] = self.add_widget_intelligent(
+                npyscreen.TitleFilename,
+                name=description+':',
+                value=str(path),
+                select_dir=True,
+                must_exist=False,
+                use_two_lines=False,
+                labelColor="GOOD",
+                begin_entry_at=32,
+                scroll_exit=True
+            )
+        self.nextrely += 1
         self.add_widget_intelligent(
             npyscreen.TitleFixedText,
             name="== LICENSE ==",
@@ -505,10 +523,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
             bad_fields.append(
                 f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
             )
-        # if not Path(opt.embedding_dir).parent.exists():
-        #     bad_fields.append(
-        #         f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
-        #     )
         if len(bad_fields) > 0:
             message = "The following problems were detected and must be corrected:\n"
             for problem in bad_fields:
@@ -528,12 +542,15 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
             "max_loaded_models",
             "xformers_enabled",
             "always_use_cpu",
-            # "embedding_dir",
-            # "lora_dir",
-            # "controlnet_dir",
         ]:
             setattr(new_opts, attr, getattr(self, attr).value)

+        for attr in self.autoimport_dirs:
+            directory = Path(self.autoimport_dirs[attr].value)
+            if directory.is_relative_to(config.root_path):
+                directory = directory.relative_to(config.root_path)
+            setattr(new_opts, attr, directory)
+
         new_opts.hf_token = self.hf_token.value
         new_opts.license_acceptance = self.license_acceptance.value
         new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
@ -595,22 +612,32 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
|||||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||||
if program_opts.yes_to_all
|
if program_opts.yes_to_all
|
||||||
else list(),
|
else list(),
|
||||||
scan_directory=None,
|
# scan_directory=None,
|
||||||
autoscan_on_startup=None,
|
# autoscan_on_startup=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# -------------------------------------
|
||||||
|
def autoimport_paths(config: InvokeAIAppConfig):
|
||||||
|
return [
|
||||||
|
('Checkpoints & diffusers models', 'autoimport_dir', config.root_path / config.autoimport_dir),
|
||||||
|
('LoRA/LyCORIS models', 'lora_dir', config.root_path / config.lora_dir),
|
||||||
|
('Controlnet models', 'controlnet_dir', config.root_path / config.controlnet_dir),
|
||||||
|
('Textual Inversion Embeddings', 'embedding_dir', config.root_path / config.embedding_dir),
|
||||||
|
]
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||||
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
|
||||||
for name in (
|
for name in (
|
||||||
"models",
|
"models",
|
||||||
"databases",
|
"databases",
|
||||||
"autoimport",
|
|
||||||
"text-inversion-output",
|
"text-inversion-output",
|
||||||
"text-inversion-training-data",
|
"text-inversion-training-data",
|
||||||
"configs"
|
"configs"
|
||||||
):
|
):
|
||||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||||
|
for model_type in ModelType:
|
||||||
|
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
configs_src = Path(configs.__path__[0])
|
configs_src = Path(configs.__path__[0])
|
||||||
configs_dest = root / "configs"
|
configs_dest = root / "configs"
|
||||||
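Note: the new `autoimport_paths()` helper centralizes the (label, config attribute, resolved path) triples, so the TUI widgets and the options writer iterate one list instead of hard-coding four directories. A standalone sketch of the pattern, assuming a simplified `Config` stand-in in place of `InvokeAIAppConfig`:

```python
from dataclasses import dataclass
from pathlib import Path

@dataclass
class Config:  # illustrative stand-in for InvokeAIAppConfig
    root_path: Path = Path("/opt/invokeai")
    autoimport_dir: Path = Path("autoimport/main")
    lora_dir: Path = Path("autoimport/lora")

def autoimport_paths(config: Config):
    # One authoritative list of (description, config attribute, absolute path).
    return [
        ("Checkpoints & diffusers models", "autoimport_dir", config.root_path / config.autoimport_dir),
        ("LoRA/LyCORIS models", "lora_dir", config.root_path / config.lora_dir),
    ]

# Every consumer iterates the same triples, e.g. when building UI fields:
for description, config_name, path in autoimport_paths(Config()):
    print(f"{description:35} {config_name:15} {path}")
```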
@ -618,9 +645,8 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
|||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
dest = root / 'models'
|
dest = root / 'models'
|
||||||
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
for model_base in BaseModelType:
|
||||||
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
|
for model_type in ModelType:
|
||||||
ModelType.ControlNet,ModelType.TextualInversion]:
|
|
||||||
path = dest / model_base.value / model_type.value
|
path = dest / model_base.value / model_type.value
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
path = dest / 'core'
|
path = dest / 'core'
|
||||||
@ -632,9 +658,7 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# with open(root / 'invokeai.yaml','w') as f:
|
|
||||||
# f.write('#empty invokeai.yaml initialization file')
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def run_console_ui(
|
def run_console_ui(
|
||||||
program_opts: Namespace, initfile: Path = None
|
program_opts: Namespace, initfile: Path = None
|
||||||
@ -680,18 +704,6 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
def default_output_dir() -> Path:
|
def default_output_dir() -> Path:
|
||||||
return config.root_path / "outputs"
|
return config.root_path / "outputs"
|
||||||
|
|
||||||
# # -------------------------------------
|
|
||||||
# def default_embedding_dir() -> Path:
|
|
||||||
# return config.root_path / "embeddings"
|
|
||||||
|
|
||||||
# # -------------------------------------
|
|
||||||
# def default_lora_dir() -> Path:
|
|
||||||
# return config.root_path / "loras"
|
|
||||||
|
|
||||||
# # -------------------------------------
|
|
||||||
# def default_controlnet_dir() -> Path:
|
|
||||||
# return config.root_path / "controlnets"
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||||
opt = default_startup_options(initfile)
|
opt = default_startup_options(initfile)
|
||||||
@@ -70,8 +70,8 @@ class ModelInstallList:
 class InstallSelections():
     install_models: List[str]= field(default_factory=list)
     remove_models: List[str]=field(default_factory=list)
-    scan_directory: Path = None
-    autoscan_on_startup: bool=False
+    # scan_directory: Path = None
+    # autoscan_on_startup: bool=False

 @dataclass
 class ModelLoadInfo():
@@ -155,8 +155,6 @@ class ModelInstall(object):
     def install(self, selections: InstallSelections):
         job = 1
         jobs = len(selections.remove_models) + len(selections.install_models)
-        if selections.scan_directory:
-            jobs += 1

         # remove requested models
         for key in selections.remove_models:
@@ -171,18 +169,8 @@ class ModelInstall(object):
             self.heuristic_install(path)
             job += 1

-        # import from the scan directory, if any
-        if path := selections.scan_directory:
-            logger.info(f'Scanning and importing models from directory {path} [{job}/{jobs}]')
-            self.heuristic_install(path)
-
         self.mgr.commit()

-        if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
-            update_autoimport_dir(selections.scan_directory)
-        else:
-            update_autoimport_dir(None)
-
     def heuristic_install(self,
                           model_path_id_or_url: Union[str,Path],
                           models_installed: Set[Path]=None)->Set[Path]:
@@ -228,7 +216,7 @@ class ModelInstall(object):
     # the model from being probed twice in the event that it has already been probed.
     def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
         try:
-            logger.info(f'Probing {path}')
+            # logger.debug(f'Probing {path}')
             info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
             model_name = path.stem if info.format=='checkpoint' else path.name
             if self.mgr.model_exists(model_name, info.base_type, info.model_type):
@@ -237,7 +225,7 @@ class ModelInstall(object):
             self.mgr.add_model(model_name = model_name,
                                base_model = info.base_type,
                                model_type = info.model_type,
-                               model_attributes = attributes
+                               model_attributes = attributes,
                                )
         except Exception as e:
             logger.warning(f'{str(e)} Skipping registration.')
@@ -309,11 +297,11 @@ class ModelInstall(object):
         return location.stem

     def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
-        # convoluted way to retrieve the description from datasets
-        description = f'{info.base_type.value} {info.model_type.value} model'
+        model_name = path.name if path.is_dir() else path.stem
+        description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
         if key := self.reverse_paths.get(self.current_id):
             if key in self.datasets:
-                description = self.datasets[key]['description']
+                description = self.datasets[key].get('description') or description

         rel_path = self.relative_to_root(path)

@@ -395,23 +383,6 @@ class ModelInstall(object):
     '''
     return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}

-def update_autoimport_dir(autodir: Path):
-    '''
-    Update the "autoimport_dir" option in invokeai.yaml
-    '''
-    with open('log.txt','a') as f:
-        print(f'autodir = {autodir}',file=f)
-
-    invokeai_config_path = config.init_file_path
-    conf = OmegaConf.load(invokeai_config_path)
-    conf.InvokeAI.Paths.autoimport_dir = str(autodir) if autodir else None
-    yaml = OmegaConf.to_yaml(conf)
-    tmpfile = invokeai_config_path.parent / "new_config.tmp"
-    with open(tmpfile, "w", encoding="utf-8") as outfile:
-        outfile.write(yaml)
-    tmpfile.replace(invokeai_config_path)
-
 # -------------------------------------
 def yes_or_no(prompt: str, default_yes=True):
     default = "y" if default_yes else "n"
|
@ -168,11 +168,27 @@ structure at initialization time by scanning the models directory. The
|
|||||||
in-memory data structure can be resynchronized by calling
|
in-memory data structure can be resynchronized by calling
|
||||||
`manager.scan_models_directory()`.
|
`manager.scan_models_directory()`.
|
||||||
|
|
||||||
Files and folders placed inside the `autoimport_dir` (path defined in
|
Files and folders placed inside the `autoimport` paths (paths
|
||||||
`invokeai.yaml`, defaulting to `ROOTDIR/autoimport` will also be
|
defined in `invokeai.yaml`) will also be scanned for new models at
|
||||||
scanned for new models at initialization time and added to
|
initialization time and added to `models.yaml`. Files will not be
|
||||||
`models.yaml`. Files will not be moved from this location but
|
moved from this location but preserved in-place. These directories
|
||||||
preserved in-place.
|
are:
|
||||||
|
|
||||||
|
configuration default description
|
||||||
|
------------- ------- -----------
|
||||||
|
autoimport_dir autoimport/main main models
|
||||||
|
lora_dir autoimport/lora LoRA/LyCORIS models
|
||||||
|
embedding_dir autoimport/embedding TI embeddings
|
||||||
|
controlnet_dir autoimport/controlnet ControlNet models
|
||||||
|
|
||||||
|
In actuality, models located in any of these directories are scanned
|
||||||
|
to determine their type, so it isn't strictly necessary to organize
|
||||||
|
the different types in this way. This entry in `invokeai.yaml` will
|
||||||
|
recursively scan all subdirectories within `autoimport`, scan models
|
||||||
|
files it finds, and import them if recognized.
|
||||||
|
|
||||||
|
Paths:
|
||||||
|
autoimport_dir: autoimport
|
||||||
|
|
||||||
A model can be manually added using `add_model()` using the model's
|
A model can be manually added using `add_model()` using the model's
|
||||||
name, base model, type and a dict of model attributes. See
|
name, base model, type and a dict of model attributes. See
|
||||||
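The deleted update_autoimport_dir() helper above wrote conf.InvokeAI.Paths.autoimport_dir, so after this change a configured invokeai.yaml would plausibly contain a Paths block like the sketch below. Only autoimport_dir appears verbatim in this diff; the other three keys are assumptions based on the configuration table above and may be spelled differently in the real schema:

    # hypothetical invokeai.yaml excerpt -- key names other than
    # autoimport_dir are assumptions taken from the table above
    InvokeAI:
      Paths:
        autoimport_dir: autoimport/main        # main models
        lora_dir: autoimport/lora              # LoRA/LyCORIS models
        embedding_dir: autoimport/embedding    # TI embeddings
        controlnet_dir: autoimport/controlnet  # ControlNet models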
@@ -208,6 +224,7 @@ checkpoint or safetensors file.

 The path points to a file or directory on disk. If a relative path,
 the root is the InvokeAI ROOTDIR.

 """
 from __future__ import annotations

@@ -566,7 +583,7 @@ class ModelManager(object):
 model_config = model_class.create_config(**model_attributes)
 model_key = self.create_key(model_name, base_model, model_type)

-if clobber or model_key not in self.models:
+if model_key in self.models and not clobber:
     raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')

 old_model = self.models.pop(model_key, None)
@@ -660,7 +677,7 @@ class ModelManager(object):
 ):
 loaded_files = set()
 new_models_found = False

 with Chdir(self.app_config.root_path):
     for model_key, model_config in list(self.models.items()):
         model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
@@ -697,49 +714,74 @@ class ModelManager(object):

 if model_path.is_relative_to(self.app_config.root_path):
     model_path = model_path.relative_to(self.app_config.root_path)
-model_config: ModelConfigBase = model_class.probe_config(str(model_path))
-self.models[model_key] = model_config
-new_models_found = True
+try:
+    model_config: ModelConfigBase = model_class.probe_config(str(model_path))
+    self.models[model_key] = model_config
+    new_models_found = True
+except NotImplementedError as e:
+    self.logger.warning(e)

 imported_models = self.autoimport()

 if (new_models_found or imported_models) and self.config_path:
     self.commit()

-def autoimport(self):
+def autoimport(self)->set[Path]:
     '''
     Scan the autoimport directory (if defined) and import new models, delete defunct models.
     '''
     # avoid circular import
     from invokeai.backend.install.model_install_backend import ModelInstall
+    from invokeai.frontend.install.model_install import ask_user_for_prediction_type

     installer = ModelInstall(config = self.app_config,
-                             model_manager = self)
+                             model_manager = self,
+                             prediction_type_helper = ask_user_for_prediction_type,
+                             )

     installed = set()
-    if not self.app_config.autoimport_dir:
-        return installed
-
-    autodir = self.app_config.root_path / self.app_config.autoimport_dir
-    if not (autodir and autodir.exists()):
-        return installed
-
-    known_paths = {(self.app_config.root_path / x['path']).resolve() for x in self.list_models()}
     scanned_dirs = set()
-    for root, dirs, files in os.walk(autodir):
-        for d in dirs:
-            path = Path(root) / d
-            if path in known_paths:
-                continue
-            if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
-                installed.update(installer.heuristic_install(path))
-                scanned_dirs.add(path)
-
-        for f in files:
-            path = Path(root) / f
-            if path in known_paths or path.parent in scanned_dirs:
-                continue
-            if path.suffix in {'.ckpt','.bin','.pth','.safetensors'}:
-                installed.update(installer.heuristic_install(path))
+    config = self.app_config
+    known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
+
+    for autodir in [config.autoimport_dir,
+                    config.lora_dir,
+                    config.embedding_dir,
+                    config.controlnet_dir]:
+        if autodir is None:
+            continue
+
+        self.logger.info(f'Scanning {autodir} for models to import')
+
+        autodir = self.app_config.root_path / autodir
+        if not autodir.exists():
+            continue
+
+        items_scanned = 0
+        new_models_found = set()
+
+        for root, dirs, files in os.walk(autodir):
+            items_scanned += len(dirs) + len(files)
+            for d in dirs:
+                path = Path(root) / d
+                if path in known_paths or path.parent in scanned_dirs:
+                    scanned_dirs.add(path)
+                    continue
+                if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
+                    new_models_found.update(installer.heuristic_install(path))
+                    scanned_dirs.add(path)
+
+            for f in files:
+                path = Path(root) / f
+                if path in known_paths or path.parent in scanned_dirs:
+                    continue
+                if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
+                    new_models_found.update(installer.heuristic_install(path))
+
+        self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
+        installed.update(new_models_found)

     return installed

 def heuristic_import(self,
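The heuristic in autoimport() is: a directory counts as a model if it contains one of a few marker files; otherwise loose files are picked up by weight-file suffix. A standalone sketch of that walk (simplified from the method above; the real code hands each hit to installer.heuristic_install()):

    import os
    from pathlib import Path

    MARKER_FILES = {'config.json', 'model_index.json', 'learned_embeds.bin'}
    WEIGHT_SUFFIXES = {'.ckpt', '.bin', '.pth', '.safetensors', '.pt'}

    def scan_for_models(autodir: Path, known_paths: set[Path]) -> set[Path]:
        found: set[Path] = set()
        scanned_dirs: set[Path] = set()
        for root, dirs, files in os.walk(autodir):
            for d in dirs:
                path = Path(root) / d
                if path in known_paths or path.parent in scanned_dirs:
                    scanned_dirs.add(path)
                    continue
                if any((path / m).exists() for m in MARKER_FILES):
                    found.add(path)          # a diffusers/TI folder is one model
                    scanned_dirs.add(path)   # so don't import its contents twice
            for f in files:
                path = Path(root) / f
                if path in known_paths or path.parent in scanned_dirs:
                    continue
                if path.suffix in WEIGHT_SUFFIXES:
                    found.add(path)          # loose checkpoint/LoRA/embedding file
        return found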
@@ -22,7 +22,7 @@ class ModelProbeInfo(object):
 variant_type: ModelVariantType
 prediction_type: SchedulerPredictionType
 upcast_attention: bool
-format: Literal['diffusers','checkpoint']
+format: Literal['diffusers','checkpoint', 'lycoris']
 image_size: int

 class ProbeBase(object):
@@ -75,22 +75,23 @@ class ModelProbe(object):
 between V2-Base and V2-768 SD models.
 '''
 if model_path:
-    format = 'diffusers' if model_path.is_dir() else 'checkpoint'
+    format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
 else:
-    format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
+    format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'

 model_info = None
 try:
     model_type = cls.get_model_type_from_folder(model_path, model) \
-        if format == 'diffusers' \
+        if format_type == 'diffusers' \
         else cls.get_model_type_from_checkpoint(model_path, model)
-    probe_class = cls.PROBES[format].get(model_type)
+    probe_class = cls.PROBES[format_type].get(model_type)
     if not probe_class:
         return None
     probe = probe_class(model_path, model, prediction_type_helper)
     base_type = probe.get_base_type()
     variant_type = probe.get_variant_type()
     prediction_type = probe.get_scheduler_prediction_type()
+    format = probe.get_format()
     model_info = ModelProbeInfo(
         model_type = model_type,
         base_type = base_type,
@@ -116,10 +117,10 @@ class ModelProbe(object):
 if model_path.name == "learned_embeds.bin":
     return ModelType.TextualInversion

-checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True)
-checkpoint = checkpoint.get("state_dict", checkpoint)
+ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
+ckpt = ckpt.get("state_dict", ckpt)

-for key in checkpoint.keys():
+for key in ckpt.keys():
     if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
         return ModelType.Main
     elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
@@ -133,7 +134,7 @@ class ModelProbe(object):

 else:
     # diffusers-ti
-    if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
+    if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
         return ModelType.TextualInversion

 raise ValueError("Unable to determine model type")
@@ -201,6 +202,9 @@ class ProbeBase(object):
 def get_scheduler_prediction_type(self)->SchedulerPredictionType:
     pass

+def get_format(self)->str:
+    pass
+
 class CheckpointProbeBase(ProbeBase):
 def __init__(self,
              checkpoint_path: Path,
@@ -214,6 +218,9 @@ class CheckpointProbeBase(ProbeBase):
 def get_base_type(self)->BaseModelType:
     pass

+def get_format(self)->str:
+    return 'checkpoint'
+
 def get_variant_type(self)-> ModelVariantType:
     model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
     if model_type != ModelType.Main:
@@ -255,7 +262,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
 return SchedulerPredictionType.Epsilon
 elif checkpoint["global_step"] == 110000:
     return SchedulerPredictionType.VPrediction
-if self.checkpoint_path and self.helper:
+if self.checkpoint_path and self.helper \
+    and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
     return self.helper(self.checkpoint_path)
 else:
     return None
@@ -266,6 +274,9 @@ class VaeCheckpointProbe(CheckpointProbeBase):
 return BaseModelType.StableDiffusion1

 class LoRACheckpointProbe(CheckpointProbeBase):
+def get_format(self)->str:
+    return 'lycoris'
+
 def get_base_type(self)->BaseModelType:
     checkpoint = self.checkpoint
     key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
@@ -285,6 +296,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
 return None

 class TextualInversionCheckpointProbe(CheckpointProbeBase):
+def get_format(self)->str:
+    return None
+
 def get_base_type(self)->BaseModelType:
     checkpoint = self.checkpoint
     if 'string_to_token' in checkpoint:
@@ -331,6 +345,9 @@ class FolderProbeBase(ProbeBase):
 def get_variant_type(self)->ModelVariantType:
     return ModelVariantType.Normal

+def get_format(self)->str:
+    return 'diffusers'
+
 class PipelineFolderProbe(FolderProbeBase):
 def get_base_type(self)->BaseModelType:
     if self.model:
@@ -386,6 +403,9 @@ class VaeFolderProbe(FolderProbeBase):
 return BaseModelType.StableDiffusion1

 class TextualInversionFolderProbe(FolderProbeBase):
+def get_format(self)->str:
+    return None
+
 def get_base_type(self)->BaseModelType:
     path = self.folder_path / 'learned_embeds.bin'
     if not path.exists():
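Every probe class now answers get_format() for itself, so ModelProbeInfo.format comes straight from whichever probe recognized the model ('diffusers', 'checkpoint', 'lycoris', or None for embeddings). A minimal standalone sketch of that dispatch pattern, with the registry reduced to a path check:

    from pathlib import Path

    class ProbeBase:
        def get_format(self) -> str:
            raise NotImplementedError

    class CheckpointProbe(ProbeBase):
        def get_format(self) -> str:
            return 'checkpoint'

    class LoRAProbe(CheckpointProbe):
        def get_format(self) -> str:
            return 'lycoris'  # LoRA/LyCORIS weights are reported as their own format

    class FolderProbe(ProbeBase):
        def get_format(self) -> str:
            return 'diffusers'

    def probe_format(model_path: Path) -> str:
        # stand-in for the PROBES registry lookup shown in the diff above
        probe = FolderProbe() if model_path.is_dir() else CheckpointProbe()
        return probe.get_format()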
@@ -397,7 +397,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
 checkpoint = safetensors.torch.load_file(path, device="cpu")
 else:
     if scan:
-        scan_result = scan_file_path(checkpoint)
+        scan_result = scan_file_path(path)
         if scan_result.infected_files != 0:
             raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
     checkpoint = torch.load(path, map_location=torch.device("meta"))
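The one-word fix above matters: the malware scanner wants the file path, not the checkpoint object (which at that point has not even been loaded). A condensed sketch of the safe-read pattern, assuming a picklescan-style scan_file_path() whose result carries an infected_files count:

    from pathlib import Path

    import safetensors.torch
    import torch
    from picklescan.scanner import scan_file_path  # assumed scanner import

    def read_checkpoint_meta(path: Path, scan: bool = False):
        if path.suffix == '.safetensors':
            # safetensors holds plain tensors; there is no pickle payload to scan
            return safetensors.torch.load_file(path, device="cpu")
        if scan:
            scan_result = scan_file_path(path)  # scan the file on disk, not the dict
            if scan_result.infected_files != 0:
                raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
        # map_location="meta" reads tensor metadata without allocating the weights
        return torch.load(path, map_location=torch.device("meta"))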
@@ -69,7 +69,7 @@ class StableDiffusion1Model(DiffusersModel):
 in_channels = unet_config['in_channels']

 else:
-    raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
+    raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")

 else:
     raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
@@ -259,8 +259,8 @@ def _convert_ckpt_and_cache(
 """
 app_config = InvokeAIAppConfig.get_config()

-weights = app_config.root_dir / model_config.path
-config_file = app_config.root_dir / model_config.config
+weights = app_config.root_path / model_config.path
+config_file = app_config.root_path / model_config.config
 output_path = Path(output_path)

 # return cached version if it exists
@@ -215,10 +215,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
 @dataclass
 class ControlNetData:
     model: ControlNetModel = Field(default=None)
-    image_tensor: torch.Tensor= Field(default=None)
-    weight: Union[float, List[float]]= Field(default=1.0)
+    image_tensor: torch.Tensor = Field(default=None)
+    weight: Union[float, List[float]] = Field(default=1.0)
     begin_step_percent: float = Field(default=0.0)
     end_step_percent: float = Field(default=1.0)
+    control_mode: str = Field(default="balanced")


 @dataclass(frozen=True)
 class ConditioningData:
@@ -599,48 +601,68 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

 # TODO: should this scaling happen here or inside self._unet_forward?
 # i.e. before or after passing it to InvokeAIDiffuserComponent
-latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+unet_latent_input = self.scheduler.scale_model_input(latents, timestep)

 # default is no controlnet, so set controlnet processing output to None
 down_block_res_samples, mid_block_res_sample = None, None

 if control_data is not None:
-    # FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
-    # if conditioning_data.guidance_scale > 1.0:
-    if conditioning_data.guidance_scale is not None:
-        # expand the latents input to control model if doing classifier free guidance
-        # (which I think for now is always true, there is conditional elsewhere that stops execution if
-        # classifier_free_guidance is <= 1.0 ?)
-        latent_control_input = torch.cat([latent_model_input] * 2)
-    else:
-        latent_control_input = latent_model_input
     # control_data should be type List[ControlNetData]
     # this loop covers both ControlNet (one ControlNetData in list)
     # and MultiControlNet (multiple ControlNetData in list)
     for i, control_datum in enumerate(control_data):
-        # print("controlnet", i, "==>", type(control_datum))
+        control_mode = control_datum.control_mode
+        # soft_injection and cfg_injection are the two ControlNet control_mode booleans
+        # that are combined at higher level to make control_mode enum
+        # soft_injection determines whether to do per-layer re-weighting adjustment (if True)
+        # or default weighting (if False)
+        soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
+        # cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
+        # or the default both conditional and unconditional (if False)
+        cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")

         first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
         last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
         # only apply controlnet if current step is within the controlnet's begin/end step range
         if step_index >= first_control_step and step_index <= last_control_step:
-            # print("running controlnet", i, "for step", step_index)
+            if cfg_injection:
+                control_latent_input = unet_latent_input
+            else:
+                # expand the latents input to control model if doing classifier free guidance
+                # (which I think for now is always true, there is conditional elsewhere that stops execution if
+                # classifier_free_guidance is <= 1.0 ?)
+                control_latent_input = torch.cat([unet_latent_input] * 2)
+
+            if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
+                encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
+            else:
+                encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
+                                                   conditioning_data.text_embeddings])
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
                 controlnet_weight = control_datum.weight[step_index]
             else:
                 # if controlnet has a single weight, use it for all steps
                 controlnet_weight = control_datum.weight

+            # controlnet(s) inference
             down_samples, mid_sample = control_datum.model(
-                sample=latent_control_input,
+                sample=control_latent_input,
                 timestep=timestep,
-                encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings,
-                                                 conditioning_data.text_embeddings]),
+                encoder_hidden_states=encoder_hidden_states,
                 controlnet_cond=control_datum.image_tensor,
-                conditioning_scale=controlnet_weight,
-                # cross_attention_kwargs,
-                guess_mode=False,
+                conditioning_scale=controlnet_weight,  # controlnet specific, NOT the guidance scale
+                guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
                 return_dict=False,
             )
+            if cfg_injection:
+                # Inferred ControlNet only for the conditional batch.
+                # To apply the output of ControlNet to both the unconditional and conditional batches,
+                # add 0 to the unconditional batch to keep it unchanged.
+                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

             if down_block_res_samples is None and mid_block_res_sample is None:
                 down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
             else:
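Because control_mode fans out into the two booleans used throughout this hunk, the mapping is easier to see in isolation. A minimal sketch (mode strings taken from the code above):

    def resolve_control_mode(control_mode: str) -> tuple[bool, bool]:
        """Map a ControlNetData.control_mode string to (soft_injection, cfg_injection)."""
        soft_injection = control_mode in ('more_prompt', 'more_control')  # per-layer re-weighting
        cfg_injection = control_mode in ('more_control', 'unbalanced')    # ControlNet on conditional batch only
        return soft_injection, cfg_injection

    # 'balanced'     -> (False, False)  default weighting, applied to both CFG batches
    # 'more_prompt'  -> (True,  False)
    # 'more_control' -> (True,  True)
    # 'unbalanced'   -> (False, True)
    print(resolve_control_mode('balanced'))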
@@ -653,11 +675,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

 # predict the noise residual
 noise_pred = self.invokeai_diffuser.do_diffusion_step(
-    latent_model_input,
-    t,
-    conditioning_data.unconditioned_embeddings,
-    conditioning_data.text_embeddings,
-    conditioning_data.guidance_scale,
+    x=unet_latent_input,
+    sigma=t,
+    unconditioning=conditioning_data.unconditioned_embeddings,
+    conditioning=conditioning_data.text_embeddings,
+    unconditional_guidance_scale=conditioning_data.guidance_scale,
     step_index=step_index,
     total_step_count=total_step_count,
     down_block_additional_residuals=down_block_res_samples,  # from controlnet(s)
@@ -962,6 +984,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 device="cuda",
 dtype=torch.float16,
 do_classifier_free_guidance=True,
+control_mode="balanced"
 ):

 if not isinstance(image, torch.Tensor):
@@ -992,6 +1015,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 repeat_by = num_images_per_prompt
 image = image.repeat_interleave(repeat_by, dim=0)
 image = image.to(device=device, dtype=dtype)
-if do_classifier_free_guidance:
+cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
+if do_classifier_free_guidance and not cfg_injection:
     image = torch.cat([image] * 2)
 return image
@@ -131,7 +131,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
 window_width=window_width,
 exclude = self.starter_models
 )
-self.pipeline_models['autoload_pending'] = True
+# self.pipeline_models['autoload_pending'] = True
 bottom_of_table = max(bottom_of_table,self.nextrely)

 self.nextrely = top_of_table
@@ -316,31 +316,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
 **kwargs,
 )

-label = "Directory to scan for models to automatically import (<tab> autocompletes):"
-self.nextrely += 1
-widgets.update(
-    autoload_directory = self.add_widget_intelligent(
-        FileBox,
-        max_height=3,
-        name=label,
-        value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else None,
-        select_dir=True,
-        must_exist=True,
-        use_two_lines=False,
-        labelColor="DANGER",
-        begin_entry_at=len(label)+1,
-        scroll_exit=True,
-    )
-)
-widgets.update(
-    autoscan_on_startup = self.add_widget_intelligent(
-        npyscreen.Checkbox,
-        name="Scan and import from this directory each time InvokeAI starts",
-        value=config.autoimport_dir is not None,
-        relx=4,
-        scroll_exit=True,
-    )
-)
 return widgets

 def resize(self):
@@ -501,8 +476,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):

 # rebuild the form, saving and restoring some of the fields that need to be preserved.
 saved_messages = self.monitor.entry_widget.values
-autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
-autoscan = self.pipeline_models['autoscan_on_startup'].value
+# autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
+# autoscan = self.pipeline_models['autoscan_on_startup'].value

 app.main_form = app.addForm(
 "MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage,
@@ -511,8 +486,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):

 app.main_form.monitor.entry_widget.values = saved_messages
 app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
-app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
-app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
+# app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
+# app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan

 def marshall_arguments(self):
 """
@@ -546,17 +521,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
 selections.install_models.extend(downloads.value.split())

 # load directory and whether to scan on startup
-if self.parentApp.autoload_pending:
-    selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
-    self.parentApp.autoload_pending = False
-selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
+# if self.parentApp.autoload_pending:
+#     selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
+#     self.parentApp.autoload_pending = False
+# selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value

 class AddModelApplication(npyscreen.NPSAppManaged):
 def __init__(self,opt):
     super().__init__()
     self.program_opts = opt
     self.user_cancelled = False
-    self.autoload_pending = True
+    # self.autoload_pending = True
     self.install_selections = InstallSelections()

 def onStart(self):
@@ -578,20 +553,20 @@ class StderrToMessage():
 # --------------------------------------------------------
 def ask_user_for_prediction_type(model_path: Path,
                                  tui_conn: Connection=None
-)->Path:
+)->SchedulerPredictionType:
 if tui_conn:
     logger.debug('Waiting for user response...')
     return _ask_user_for_pt_tui(model_path, tui_conn)
 else:
     return _ask_user_for_pt_cmdline(model_path)

-def _ask_user_for_pt_cmdline(model_path):
+def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
 choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
 print(
 f"""
 Please select the type of the V2 checkpoint named {model_path.name}:
-[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file)
-[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file)
+[1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
+[2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
 [3] Skip this model and come back later.
 """
 )
@@ -608,7 +583,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
 return
 return choice

-def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path:
+def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
 try:
     tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
     # note that we don't do any status checking here
@@ -810,13 +785,15 @@ def main():
 logger.error(
 "Insufficient vertical space for the interface. Please make your window taller and try again"
 )
-elif str(e).startswith("addwstr"):
+    input('Press any key to continue...')
+except Exception as e:
+    if str(e).startswith("addwstr"):
         logger.error(
         "Insufficient horizontal space for the interface. Please make your window wider and try again."
         )
-except Exception as e:
+    else:
         print(f'An exception has occurred: {str(e)} Details:')
         print(traceback.format_exc(), file=sys.stderr)
         input('Press any key to continue...')


@@ -42,6 +42,18 @@ def set_terminal_size(columns: int, lines: int, launch_command: str=None):
 elif OS in ["Darwin", "Linux"]:
     _set_terminal_size_unix(width,height)

+# check whether it worked....
+ts = get_terminal_size()
+pause = False
+if ts.columns < columns:
+    print('\033[1mThis window is too narrow for the user interface. Please make it wider.\033[0m')
+    pause = True
+if ts.lines < lines:
+    print('\033[1mThis window is too short for the user interface. Please make it taller.\033[0m')
+    pause = True
+if pause:
+    input('Press any key to continue..')
+
 def _set_terminal_size_powershell(width: int, height: int):
 script=f'''
 $pshost = get-host
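The read-back check added above can be reproduced with nothing but the standard library, assuming the get_terminal_size used here is shutil.get_terminal_size (os.get_terminal_size behaves the same for an attached tty):

    from shutil import get_terminal_size

    def require_terminal_size(columns: int, lines: int) -> None:
        """Warn and pause if the window is smaller than the TUI needs."""
        ts = get_terminal_size()
        if ts.columns < columns or ts.lines < lines:
            print('\033[1mThis window is too small for the user interface. Please enlarge it.\033[0m')
            input('Press any key to continue..')

    require_terminal_size(130, 45)  # hypothetical minimum size for the install TUI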
14
invokeai/frontend/web/config/common.ts
Normal file
14
invokeai/frontend/web/config/common.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import react from '@vitejs/plugin-react-swc';
|
||||||
|
import { visualizer } from 'rollup-plugin-visualizer';
|
||||||
|
import { PluginOption, UserConfig } from 'vite';
|
||||||
|
import eslint from 'vite-plugin-eslint';
|
||||||
|
import tsconfigPaths from 'vite-tsconfig-paths';
|
||||||
|
import { nodePolyfills } from 'vite-plugin-node-polyfills';
|
||||||
|
|
||||||
|
export const commonPlugins: UserConfig['plugins'] = [
|
||||||
|
react(),
|
||||||
|
eslint(),
|
||||||
|
tsconfigPaths(),
|
||||||
|
visualizer() as unknown as PluginOption,
|
||||||
|
nodePolyfills(),
|
||||||
|
];
|
@ -1,17 +1,9 @@
|
|||||||
import react from '@vitejs/plugin-react-swc';
|
import { UserConfig } from 'vite';
|
||||||
import { visualizer } from 'rollup-plugin-visualizer';
|
import { commonPlugins } from './common';
|
||||||
import { PluginOption, UserConfig } from 'vite';
|
|
||||||
import eslint from 'vite-plugin-eslint';
|
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
|
||||||
|
|
||||||
export const appConfig: UserConfig = {
|
export const appConfig: UserConfig = {
|
||||||
base: './',
|
base: './',
|
||||||
plugins: [
|
plugins: [...commonPlugins],
|
||||||
react(),
|
|
||||||
eslint(),
|
|
||||||
tsconfigPaths(),
|
|
||||||
visualizer() as unknown as PluginOption,
|
|
||||||
],
|
|
||||||
build: {
|
build: {
|
||||||
chunkSizeWarningLimit: 1500,
|
chunkSizeWarningLimit: 1500,
|
||||||
},
|
},
|
||||||
|
@ -1,19 +1,13 @@
|
|||||||
import react from '@vitejs/plugin-react-swc';
|
|
||||||
import path from 'path';
|
import path from 'path';
|
||||||
import { visualizer } from 'rollup-plugin-visualizer';
|
import { UserConfig } from 'vite';
|
||||||
import { PluginOption, UserConfig } from 'vite';
|
|
||||||
import dts from 'vite-plugin-dts';
|
import dts from 'vite-plugin-dts';
|
||||||
import eslint from 'vite-plugin-eslint';
|
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
|
||||||
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
|
||||||
|
import { commonPlugins } from './common';
|
||||||
|
|
||||||
export const packageConfig: UserConfig = {
|
export const packageConfig: UserConfig = {
|
||||||
base: './',
|
base: './',
|
||||||
plugins: [
|
plugins: [
|
||||||
react(),
|
...commonPlugins,
|
||||||
eslint(),
|
|
||||||
tsconfigPaths(),
|
|
||||||
visualizer() as unknown as PluginOption,
|
|
||||||
dts({
|
dts({
|
||||||
insertTypesEntry: true,
|
insertTypesEntry: true,
|
||||||
}),
|
}),
|
||||||
|
@ -23,7 +23,7 @@
|
|||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/schema.d.ts -t",
|
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:madge": "madge --circular src/main.tsx",
|
"lint:madge": "madge --circular src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
@ -53,36 +53,38 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@apidevtools/swagger-parser": "^10.1.0",
|
||||||
"@chakra-ui/anatomy": "^2.1.1",
|
"@chakra-ui/anatomy": "^2.1.1",
|
||||||
"@chakra-ui/icons": "^2.0.19",
|
"@chakra-ui/icons": "^2.0.19",
|
||||||
"@chakra-ui/react": "^2.6.0",
|
"@chakra-ui/react": "^2.7.1",
|
||||||
"@chakra-ui/styled-system": "^2.9.0",
|
"@chakra-ui/styled-system": "^2.9.1",
|
||||||
"@chakra-ui/theme-tools": "^2.0.16",
|
"@chakra-ui/theme-tools": "^2.0.18",
|
||||||
"@dagrejs/graphlib": "^2.1.12",
|
"@dagrejs/graphlib": "^2.1.13",
|
||||||
"@dnd-kit/core": "^6.0.8",
|
"@dnd-kit/core": "^6.0.8",
|
||||||
"@dnd-kit/modifiers": "^6.0.1",
|
"@dnd-kit/modifiers": "^6.0.1",
|
||||||
"@emotion/react": "^11.11.1",
|
"@emotion/react": "^11.11.1",
|
||||||
"@emotion/styled": "^11.10.6",
|
"@emotion/styled": "^11.11.0",
|
||||||
"@floating-ui/react-dom": "^2.0.0",
|
"@floating-ui/react-dom": "^2.0.1",
|
||||||
"@fontsource/inter": "^4.5.15",
|
"@fontsource-variable/inter": "^5.0.3",
|
||||||
"@mantine/core": "^6.0.13",
|
"@fontsource/inter": "^5.0.3",
|
||||||
"@mantine/hooks": "^6.0.13",
|
"@mantine/core": "^6.0.14",
|
||||||
|
"@mantine/hooks": "^6.0.14",
|
||||||
"@reduxjs/toolkit": "^1.9.5",
|
"@reduxjs/toolkit": "^1.9.5",
|
||||||
"@roarr/browser-log-writer": "^1.1.5",
|
"@roarr/browser-log-writer": "^1.1.5",
|
||||||
"chakra-ui-contextmenu": "^1.0.5",
|
"chakra-ui-contextmenu": "^1.0.5",
|
||||||
"dateformat": "^5.0.3",
|
"dateformat": "^5.0.3",
|
||||||
"downshift": "^7.6.0",
|
"downshift": "^7.6.0",
|
||||||
"formik": "^2.2.9",
|
"formik": "^2.4.2",
|
||||||
"framer-motion": "^10.12.4",
|
"framer-motion": "^10.12.17",
|
||||||
"fuse.js": "^6.6.2",
|
"fuse.js": "^6.6.2",
|
||||||
"i18next": "^22.4.15",
|
"i18next": "^23.2.3",
|
||||||
"i18next-browser-languagedetector": "^7.0.1",
|
"i18next-browser-languagedetector": "^7.0.2",
|
||||||
"i18next-http-backend": "^2.2.0",
|
"i18next-http-backend": "^2.2.1",
|
||||||
"konva": "^9.0.1",
|
"konva": "^9.2.0",
|
||||||
"lodash-es": "^4.17.21",
|
"lodash-es": "^4.17.21",
|
||||||
"nanostores": "^0.9.2",
|
"nanostores": "^0.9.2",
|
||||||
"openapi-fetch": "^0.4.0",
|
"openapi-fetch": "^0.4.0",
|
||||||
"overlayscrollbars": "^2.1.1",
|
"overlayscrollbars": "^2.2.0",
|
||||||
"overlayscrollbars-react": "^0.5.0",
|
"overlayscrollbars-react": "^0.5.0",
|
||||||
"patch-package": "^7.0.0",
|
"patch-package": "^7.0.0",
|
||||||
"query-string": "^8.1.0",
|
"query-string": "^8.1.0",
|
||||||
@ -92,21 +94,21 @@
|
|||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-dropzone": "^14.2.3",
|
"react-dropzone": "^14.2.3",
|
||||||
"react-hotkeys-hook": "4.4.0",
|
"react-hotkeys-hook": "4.4.0",
|
||||||
"react-i18next": "^12.2.2",
|
"react-i18next": "^13.0.1",
|
||||||
"react-icons": "^4.9.0",
|
"react-icons": "^4.10.1",
|
||||||
"react-konva": "^18.2.7",
|
"react-konva": "^18.2.10",
|
||||||
"react-redux": "^8.0.5",
|
"react-redux": "^8.1.1",
|
||||||
"react-resizable-panels": "^0.0.42",
|
"react-resizable-panels": "^0.0.52",
|
||||||
"react-use": "^17.4.0",
|
"react-use": "^17.4.0",
|
||||||
"react-virtuoso": "^4.3.5",
|
"react-virtuoso": "^4.3.11",
|
||||||
"react-zoom-pan-pinch": "^3.0.7",
|
"react-zoom-pan-pinch": "^3.0.8",
|
||||||
"reactflow": "^11.7.0",
|
"reactflow": "^11.7.4",
|
||||||
"redux-dynamic-middlewares": "^2.2.0",
|
"redux-dynamic-middlewares": "^2.2.0",
|
||||||
"redux-remember": "^3.3.1",
|
"redux-remember": "^3.3.1",
|
||||||
"roarr": "^7.15.0",
|
"roarr": "^7.15.0",
|
||||||
"serialize-error": "^11.0.0",
|
"serialize-error": "^11.0.0",
|
||||||
"socket.io-client": "^4.6.0",
|
"socket.io-client": "^4.7.0",
|
||||||
"use-image": "^1.1.0",
|
"use-image": "^1.1.1",
|
||||||
"uuid": "^9.0.0",
|
"uuid": "^9.0.0",
|
||||||
"zod": "^3.21.4"
|
"zod": "^3.21.4"
|
||||||
},
|
},
|
||||||
@ -117,22 +119,22 @@
|
|||||||
"ts-toolbelt": "^9.6.0"
|
"ts-toolbelt": "^9.6.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@chakra-ui/cli": "^2.4.0",
|
"@chakra-ui/cli": "^2.4.1",
|
||||||
"@types/dateformat": "^5.0.0",
|
"@types/dateformat": "^5.0.0",
|
||||||
"@types/lodash-es": "^4.14.194",
|
"@types/lodash-es": "^4.14.194",
|
||||||
"@types/node": "^18.16.2",
|
"@types/node": "^20.3.1",
|
||||||
"@types/react": "^18.2.0",
|
"@types/react": "^18.2.14",
|
||||||
"@types/react-dom": "^18.2.1",
|
"@types/react-dom": "^18.2.6",
|
||||||
"@types/react-redux": "^7.1.25",
|
"@types/react-redux": "^7.1.25",
|
||||||
"@types/react-transition-group": "^4.4.5",
|
"@types/react-transition-group": "^4.4.6",
|
||||||
"@types/uuid": "^9.0.0",
|
"@types/uuid": "^9.0.2",
|
||||||
"@typescript-eslint/eslint-plugin": "^5.59.1",
|
"@typescript-eslint/eslint-plugin": "^5.60.0",
|
||||||
"@typescript-eslint/parser": "^5.59.1",
|
"@typescript-eslint/parser": "^5.60.0",
|
||||||
"@vitejs/plugin-react-swc": "^3.3.0",
|
"@vitejs/plugin-react-swc": "^3.3.2",
|
||||||
"axios": "^1.4.0",
|
"axios": "^1.4.0",
|
||||||
"babel-plugin-transform-imports": "^2.0.0",
|
"babel-plugin-transform-imports": "^2.0.0",
|
||||||
"concurrently": "^8.0.1",
|
"concurrently": "^8.2.0",
|
||||||
"eslint": "^8.39.0",
|
"eslint": "^8.43.0",
|
||||||
"eslint-config-prettier": "^8.8.0",
|
"eslint-config-prettier": "^8.8.0",
|
||||||
"eslint-plugin-prettier": "^4.2.1",
|
"eslint-plugin-prettier": "^4.2.1",
|
||||||
"eslint-plugin-react": "^7.32.2",
|
"eslint-plugin-react": "^7.32.2",
|
||||||
@ -140,19 +142,20 @@
|
|||||||
"form-data": "^4.0.0",
|
"form-data": "^4.0.0",
|
||||||
"husky": "^8.0.3",
|
"husky": "^8.0.3",
|
||||||
"lint-staged": "^13.2.2",
|
"lint-staged": "^13.2.2",
|
||||||
"madge": "^6.0.0",
|
"madge": "^6.1.0",
|
||||||
"openapi-types": "^12.1.0",
|
"openapi-types": "^12.1.3",
|
||||||
"openapi-typescript": "^6.2.8",
|
"openapi-typescript": "^6.2.8",
|
||||||
"openapi-typescript-codegen": "^0.24.0",
|
"openapi-typescript-codegen": "^0.24.0",
|
||||||
"postinstall-postinstall": "^2.1.0",
|
"postinstall-postinstall": "^2.1.0",
|
||||||
"prettier": "^2.8.8",
|
"prettier": "^2.8.8",
|
||||||
"rollup-plugin-visualizer": "^5.9.0",
|
"rollup-plugin-visualizer": "^5.9.2",
|
||||||
"terser": "^5.17.1",
|
"terser": "^5.18.1",
|
||||||
"ts-toolbelt": "^9.6.0",
|
"ts-toolbelt": "^9.6.0",
|
||||||
"vite": "^4.3.3",
|
"vite": "^4.3.9",
|
||||||
"vite-plugin-css-injected-by-js": "^3.1.1",
|
"vite-plugin-css-injected-by-js": "^3.1.1",
|
||||||
"vite-plugin-dts": "^2.3.0",
|
"vite-plugin-dts": "^2.3.0",
|
||||||
"vite-plugin-eslint": "^1.8.1",
|
"vite-plugin-eslint": "^1.8.1",
|
||||||
|
"vite-plugin-node-polyfills": "^0.9.0",
|
||||||
"vite-tsconfig-paths": "^4.2.0",
|
"vite-tsconfig-paths": "^4.2.0",
|
||||||
"yarn": "^1.22.19"
|
"yarn": "^1.22.19"
|
||||||
}
|
}
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
diff --git a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
|
|
||||||
index 937cf0d..7dcc0c0 100644
|
|
||||||
--- a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
|
|
||||||
+++ b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
|
|
||||||
@@ -50,7 +50,8 @@ async function readTheme(themeFilePath) {
|
|
||||||
project: tsConfig.configFileAbsolutePath,
|
|
||||||
compilerOptions: {
|
|
||||||
module: "CommonJS",
|
|
||||||
- esModuleInterop: true
|
|
||||||
+ esModuleInterop: true,
|
|
||||||
+ jsx: 'react'
|
|
||||||
},
|
|
||||||
transpileOnly: true,
|
|
||||||
swc: true
|
|
@ -524,7 +524,8 @@
|
|||||||
"initialImage": "Initial Image",
|
"initialImage": "Initial Image",
|
||||||
"showOptionsPanel": "Show Options Panel",
|
"showOptionsPanel": "Show Options Panel",
|
||||||
"hidePreview": "Hide Preview",
|
"hidePreview": "Hide Preview",
|
||||||
"showPreview": "Show Preview"
|
"showPreview": "Show Preview",
|
||||||
|
"controlNetControlMode": "Control Mode"
|
||||||
},
|
},
|
||||||
"settings": {
|
"settings": {
|
||||||
"models": "Models",
|
"models": "Models",
|
||||||
|
@ -48,7 +48,7 @@ const App = ({
|
|||||||
const isApplicationReady = useIsApplicationReady();
|
const isApplicationReady = useIsApplicationReady();
|
||||||
|
|
||||||
const { data: pipelineModels } = useListModelsQuery({
|
const { data: pipelineModels } = useListModelsQuery({
|
||||||
model_type: 'pipeline',
|
model_type: 'main',
|
||||||
});
|
});
|
||||||
const { data: controlnetModels } = useListModelsQuery({
|
const { data: controlnetModels } = useListModelsQuery({
|
||||||
model_type: 'controlnet',
|
model_type: 'controlnet',
|
||||||
|
@ -14,7 +14,7 @@ import { invokeAIThemeColors } from 'theme/colors/invokeAI';
|
|||||||
import { lightThemeColors } from 'theme/colors/lightTheme';
|
import { lightThemeColors } from 'theme/colors/lightTheme';
|
||||||
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
import { oceanBlueColors } from 'theme/colors/oceanBlue';
|
||||||
|
|
||||||
import '@fontsource/inter/variable.css';
|
import '@fontsource-variable/inter';
|
||||||
import { MantineProvider } from '@mantine/core';
|
import { MantineProvider } from '@mantine/core';
|
||||||
import { mantineTheme } from 'mantine-theme/theme';
|
import { mantineTheme } from 'mantine-theme/theme';
|
||||||
import 'overlayscrollbars/overlayscrollbars.css';
|
import 'overlayscrollbars/overlayscrollbars.css';
|
||||||
|
@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||||
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
|
||||||
import { nodesSelecter } from 'features/nodes/store/nodesSlice';
|
import { nodesSelector } from 'features/nodes/store/nodesSlice';
|
||||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
import { some } from 'lodash-es';
|
import { some } from 'lodash-es';
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
|
|||||||
[
|
[
|
||||||
generationSelector,
|
generationSelector,
|
||||||
canvasSelector,
|
canvasSelector,
|
||||||
nodesSelecter,
|
nodesSelector,
|
||||||
controlNetSelector,
|
controlNetSelector,
|
||||||
(state: RootState, image_name?: string) => image_name,
|
(state: RootState, image_name?: string) => image_name,
|
||||||
],
|
],
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { AnyAction } from '@reduxjs/toolkit';
|
import { AnyAction } from '@reduxjs/toolkit';
|
||||||
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { forEach } from 'lodash-es';
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
import { Graph } from 'services/api/types';
|
import { Graph } from 'services/api/types';
|
||||||
|
|
||||||
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
||||||
@ -8,17 +9,6 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
|||||||
if (action.payload.nodes) {
|
if (action.payload.nodes) {
|
||||||
const sanitizedNodes: Graph['nodes'] = {};
|
const sanitizedNodes: Graph['nodes'] = {};
|
||||||
|
|
||||||
// Sanitize nodes as needed
|
|
||||||
forEach(action.payload.nodes, (node, key) => {
|
|
||||||
// Don't log the whole freaking dataURL
|
|
||||||
if (node.type === 'dataURL_image') {
|
|
||||||
const { dataURL, ...rest } = node;
|
|
||||||
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
|
|
||||||
} else {
|
|
||||||
sanitizedNodes[key] = { ...node };
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
...action,
|
...action,
|
||||||
payload: { ...action.payload, nodes: sanitizedNodes },
|
payload: { ...action.payload, nodes: sanitizedNodes },
|
||||||
@ -26,5 +16,19 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (receivedOpenAPISchema.fulfilled.match(action)) {
|
||||||
|
return {
|
||||||
|
...action,
|
||||||
|
payload: '<OpenAPI schema omitted>',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nodeTemplatesBuilt.match(action)) {
|
||||||
|
return {
|
||||||
|
...action,
|
||||||
|
payload: '<Node templates omitted>',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
return action;
|
return action;
|
||||||
};
|
};
|
||||||
|
@ -82,6 +82,7 @@ import {
|
|||||||
addImageRemovedFromBoardFulfilledListener,
|
addImageRemovedFromBoardFulfilledListener,
|
||||||
addImageRemovedFromBoardRejectedListener,
|
addImageRemovedFromBoardRejectedListener,
|
||||||
} from './listeners/imageRemovedFromBoard';
|
} from './listeners/imageRemovedFromBoard';
|
||||||
|
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
|
|||||||
addImageRemovedFromBoardFulfilledListener();
|
addImageRemovedFromBoardFulfilledListener();
|
||||||
addImageRemovedFromBoardRejectedListener();
|
addImageRemovedFromBoardRejectedListener();
|
||||||
addBoardIdSelectedListener();
|
addBoardIdSelectedListener();
|
||||||
|
|
||||||
|
// Node schemas
|
||||||
|
addReceivedOpenAPISchemaListener();
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
import { parseSchema } from 'features/nodes/util/parseSchema';
|
||||||
|
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { size } from 'lodash-es';
|
||||||
|
|
||||||
|
const schemaLog = log.child({ namespace: 'schema' });
|
||||||
|
|
||||||
|
export const addReceivedOpenAPISchemaListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: receivedOpenAPISchema.fulfilled,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
const schemaJSON = action.payload;
|
||||||
|
|
||||||
|
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
|
||||||
|
|
||||||
|
const nodeTemplates = parseSchema(schemaJSON);
|
||||||
|
|
||||||
|
schemaLog.info(
|
||||||
|
{ data: { nodeTemplates } },
|
||||||
|
`Built ${size(nodeTemplates)} node templates`
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(nodeTemplatesBuilt(nodeTemplates));
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: receivedOpenAPISchema.rejected,
|
||||||
|
effect: (action, { dispatch, getState }) => {
|
||||||
|
schemaLog.error('Problem dereferencing OpenAPI Schema');
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@@ -3,7 +3,7 @@ import { startAppListening } from '..';
 import { createSelector } from '@reduxjs/toolkit';
 import { generationSelector } from 'features/parameters/store/generationSelectors';
 import { canvasSelector } from 'features/canvas/store/canvasSelectors';
-import { nodesSelecter } from 'features/nodes/store/nodesSlice';
+import { nodesSelector } from 'features/nodes/store/nodesSlice';
 import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
 import { forEach, uniqBy } from 'lodash-es';
 import { imageUrlsReceived } from 'services/api/thunks/image';
@@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
   [
     generationSelector,
     canvasSelector,
-    nodesSelecter,
+    nodesSelector,
     controlNetSelector,
     selectImagesEntities,
   ],
@@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
 import configReducer from 'features/system/store/configSlice';
 import hotkeysReducer from 'features/ui/store/hotkeysSlice';
 import uiReducer from 'features/ui/store/uiSlice';
+import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
 
 import { listenerMiddleware } from './middleware/listenerMiddleware';
 
@@ -48,6 +49,7 @@ const allReducers = {
   controlNet: controlNetReducer,
   boards: boardsReducer,
   // session: sessionReducer,
+  dynamicPrompts: dynamicPromptsReducer,
   [api.reducerPath]: api.reducer,
 };
 
@@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
   'system',
   'ui',
   'controlNet',
+  'dynamicPrompts',
   // 'boards',
   // 'hotkeys',
   // 'config',
@@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
 export type RootState = ReturnType<typeof store.getState>;
 export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
 export type AppDispatch = typeof store.dispatch;
+export const stateSelector = (state: RootState) => state;
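`stateSelector`, added above, is just the identity selector over RootState. It exists so memoized selectors can be derived from the whole state with createSelector; the new dynamic-prompts components later in this diff all follow this pattern. Sketch:

import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

// memoized view of just the dynamicPrompts slice
const selectDynamicPrompts = createSelector(
  stateSelector,
  (state) => state.dynamicPrompts,
  defaultSelectorOptions
);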
@@ -171,6 +171,14 @@ export type AppConfig = {
       fineStep: number;
       coarseStep: number;
     };
+    dynamicPrompts: {
+      maxPrompts: {
+        initial: number;
+        min: number;
+        sliderMax: number;
+        inputMax: number;
+      };
+    };
   };
 };
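For reference, a config fragment satisfying the new `dynamicPrompts` branch of AppConfig (nested under `sd`, as the slider component later in this diff reads it) could look like this; the numbers are illustrative, not necessarily the shipped defaults:

const sd = {
  // ...existing slider config such as fineStep / coarseStep...
  dynamicPrompts: {
    maxPrompts: {
      initial: 100,
      min: 1,
      sliderMax: 1000,
      inputMax: 10000,
    },
  },
};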
@@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
           borderWidth: '2px',
           borderColor: 'var(--invokeai-colors-base-800)',
           color: 'var(--invokeai-colors-base-100)',
-          padding: 10,
           paddingRight: 24,
           fontWeight: 600,
           '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },
@@ -34,6 +34,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
           '&:focus': {
             borderColor: 'var(--invokeai-colors-accent-600)',
           },
+          '&:disabled': {
+            backgroundColor: 'var(--invokeai-colors-base-700)',
+            color: 'var(--invokeai-colors-base-400)',
+          },
         },
         dropdown: {
           backgroundColor: 'var(--invokeai-colors-base-800)',
@@ -64,7 +68,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
           },
         },
         rightSection: {
-          width: 24,
+          width: 32,
         },
       })}
       {...rest}
@@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
       {...formControlProps}
     >
       {label && (
-        <FormLabel my={1} flexGrow={1} {...formLabelProps}>
+        <FormLabel
+          my={1}
+          flexGrow={1}
+          sx={{
+            cursor: isDisabled ? 'not-allowed' : 'pointer',
+            ...formLabelProps?.sx,
+          }}
+          {...formLabelProps}
+        >
           {label}
         </FormLabel>
       )}
@@ -9,10 +9,12 @@ type IAICanvasImageProps = {
 };
 const IAICanvasImage = (props: IAICanvasImageProps) => {
   const { width, height, x, y, imageName } = props.canvasImage;
-  const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken);
+  const { currentData: imageDTO, isError } = useGetImageDTOQuery(
+    imageName ?? skipToken
+  );
   const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
 
-  if (!imageDTO) {
+  if (isError) {
     return <Rect x={x} y={y} width={width} height={height} fill="red" />;
   }
 
@@ -1,26 +1,27 @@
+import { Box, ChakraProps, Flex } from '@chakra-ui/react';
+import { useAppDispatch } from 'app/store/storeHooks';
 import { memo, useCallback } from 'react';
+import { FaCopy, FaTrash } from 'react-icons/fa';
 import {
   ControlNetConfig,
   controlNetAdded,
   controlNetRemoved,
   controlNetToggled,
 } from '../store/controlNetSlice';
-import { useAppDispatch } from 'app/store/storeHooks';
 import ParamControlNetModel from './parameters/ParamControlNetModel';
 import ParamControlNetWeight from './parameters/ParamControlNetWeight';
-import { Flex, Box, ChakraProps } from '@chakra-ui/react';
-import { FaCopy, FaTrash } from 'react-icons/fa';
-
-import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
-import ControlNetImagePreview from './ControlNetImagePreview';
-import IAIIconButton from 'common/components/IAIIconButton';
-import { v4 as uuidv4 } from 'uuid';
-import { useToggle } from 'react-use';
-import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
-import ControlNetProcessorComponent from './ControlNetProcessorComponent';
-import IAISwitch from 'common/components/IAISwitch';
 import { ChevronUpIcon } from '@chakra-ui/icons';
+import IAIIconButton from 'common/components/IAIIconButton';
+import IAISwitch from 'common/components/IAISwitch';
+import { useToggle } from 'react-use';
+import { v4 as uuidv4 } from 'uuid';
+import ControlNetImagePreview from './ControlNetImagePreview';
+import ControlNetProcessorComponent from './ControlNetProcessorComponent';
 import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
+import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
+import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
+import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
 
 const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
 
@@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
     weight,
     beginStepPct,
     endStepPct,
+    controlMode,
     controlImage,
     processedControlImage,
     processorNode,
|
|||||||
</Flex>
|
</Flex>
|
||||||
{isEnabled && (
|
{isEnabled && (
|
||||||
<>
|
<>
|
||||||
<Flex sx={{ gap: 4, w: 'full' }}>
|
<Flex sx={{ w: 'full', flexDirection: 'column' }}>
|
||||||
<Flex
|
<Flex sx={{ gap: 4, w: 'full' }}>
|
||||||
sx={{
|
|
||||||
flexDir: 'column',
|
|
||||||
gap: 2,
|
|
||||||
w: 'full',
|
|
||||||
h: isExpanded ? 28 : 24,
|
|
||||||
paddingInlineStart: 1,
|
|
||||||
paddingInlineEnd: isExpanded ? 1 : 0,
|
|
||||||
pb: 2,
|
|
||||||
justifyContent: 'space-between',
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<ParamControlNetWeight
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
weight={weight}
|
|
||||||
mini={!isExpanded}
|
|
||||||
/>
|
|
||||||
<ParamControlNetBeginEnd
|
|
||||||
controlNetId={controlNetId}
|
|
||||||
beginStepPct={beginStepPct}
|
|
||||||
endStepPct={endStepPct}
|
|
||||||
mini={!isExpanded}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
{!isExpanded && (
|
|
||||||
<Flex
|
<Flex
|
||||||
sx={{
|
sx={{
|
||||||
alignItems: 'center',
|
flexDir: 'column',
|
||||||
justifyContent: 'center',
|
gap: 3,
|
||||||
h: 24,
|
w: 'full',
|
||||||
w: 24,
|
paddingInlineStart: 1,
|
||||||
aspectRatio: '1/1',
|
paddingInlineEnd: isExpanded ? 1 : 0,
|
||||||
|
pb: 2,
|
||||||
|
justifyContent: 'space-between',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<ControlNetImagePreview
|
<ParamControlNetWeight
|
||||||
controlNet={props.controlNet}
|
controlNetId={controlNetId}
|
||||||
height={24}
|
weight={weight}
|
||||||
|
mini={!isExpanded}
|
||||||
|
/>
|
||||||
|
<ParamControlNetBeginEnd
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
beginStepPct={beginStepPct}
|
||||||
|
endStepPct={endStepPct}
|
||||||
|
mini={!isExpanded}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
{!isExpanded && (
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
alignItems: 'center',
|
||||||
|
justifyContent: 'center',
|
||||||
|
h: 24,
|
||||||
|
w: 24,
|
||||||
|
aspectRatio: '1/1',
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<ControlNetImagePreview
|
||||||
|
controlNet={props.controlNet}
|
||||||
|
height={24}
|
||||||
|
/>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
<ParamControlNetControlMode
|
||||||
|
controlNetId={controlNetId}
|
||||||
|
controlMode={controlMode}
|
||||||
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
{isExpanded && (
|
{isExpanded && (
|
||||||
<>
|
<>
|
||||||
<Box mt={2}>
|
<Box mt={2}>
|
||||||
|
@@ -0,0 +1,45 @@
+import { useAppDispatch } from 'app/store/storeHooks';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import {
+  ControlModes,
+  controlNetControlModeChanged,
+} from 'features/controlNet/store/controlNetSlice';
+import { useCallback } from 'react';
+import { useTranslation } from 'react-i18next';
+
+type ParamControlNetControlModeProps = {
+  controlNetId: string;
+  controlMode: string;
+};
+
+const CONTROL_MODE_DATA = [
+  { label: 'Balanced', value: 'balanced' },
+  { label: 'Prompt', value: 'more_prompt' },
+  { label: 'Control', value: 'more_control' },
+  { label: 'Mega Control', value: 'unbalanced' },
+];
+
+export default function ParamControlNetControlMode(
+  props: ParamControlNetControlModeProps
+) {
+  const { controlNetId, controlMode = false } = props;
+  const dispatch = useAppDispatch();
+
+  const { t } = useTranslation();
+
+  const handleControlModeChange = useCallback(
+    (controlMode: ControlModes) => {
+      dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
+    },
+    [controlNetId, dispatch]
+  );
+
+  return (
+    <IAIMantineSelect
+      label={t('parameters.controlNetControlMode')}
+      data={CONTROL_MODE_DATA}
+      value={String(controlMode)}
+      onChange={handleControlModeChange}
+    />
+  );
+}
@@ -1,6 +1,5 @@
 import {
   ControlNetProcessorType,
-  RequiredCannyImageProcessorInvocation,
   RequiredControlNetProcessorNode,
 } from './types';
 
@@ -23,7 +22,7 @@ type ControlNetProcessorsDict = Record<
  *
  * TODO: Generate from the OpenAPI schema
  */
-export const CONTROLNET_PROCESSORS = {
+export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
   none: {
     type: 'none',
     label: 'none',
@@ -174,6 +173,8 @@ export const CONTROLNET_PROCESSORS = {
   },
 };
 
+type ControlNetModelsDict = Record<string, ControlNetModel>;
+
 type ControlNetModel = {
   type: string;
   label: string;
@@ -181,7 +182,7 @@ type ControlNetModel = {
   defaultProcessor?: ControlNetProcessorType;
 };
 
-export const CONTROLNET_MODELS = {
+export const CONTROLNET_MODELS: ControlNetModelsDict = {
   'lllyasviel/control_v11p_sd15_canny': {
     type: 'lllyasviel/control_v11p_sd15_canny',
     label: 'Canny',
@@ -190,6 +191,7 @@ export const CONTROLNET_MODELS = {
   'lllyasviel/control_v11p_sd15_inpaint': {
     type: 'lllyasviel/control_v11p_sd15_inpaint',
     label: 'Inpaint',
+    defaultProcessor: 'none',
   },
   'lllyasviel/control_v11p_sd15_mlsd': {
     type: 'lllyasviel/control_v11p_sd15_mlsd',
@@ -209,6 +211,7 @@ export const CONTROLNET_MODELS = {
   'lllyasviel/control_v11p_sd15_seg': {
     type: 'lllyasviel/control_v11p_sd15_seg',
     label: 'Segmentation',
+    defaultProcessor: 'none',
   },
   'lllyasviel/control_v11p_sd15_lineart': {
     type: 'lllyasviel/control_v11p_sd15_lineart',
@@ -223,6 +226,7 @@ export const CONTROLNET_MODELS = {
   'lllyasviel/control_v11p_sd15_scribble': {
     type: 'lllyasviel/control_v11p_sd15_scribble',
     label: 'Scribble',
+    defaultProcessor: 'none',
   },
   'lllyasviel/control_v11p_sd15_softedge': {
     type: 'lllyasviel/control_v11p_sd15_softedge',
@@ -242,10 +246,12 @@ export const CONTROLNET_MODELS = {
   'lllyasviel/control_v11f1e_sd15_tile': {
     type: 'lllyasviel/control_v11f1e_sd15_tile',
     label: 'Tile (experimental)',
+    defaultProcessor: 'none',
   },
   'lllyasviel/control_v11e_sd15_ip2p': {
     type: 'lllyasviel/control_v11e_sd15_ip2p',
     label: 'Pix2Pix (experimental)',
+    defaultProcessor: 'none',
   },
   'CrucibleAI/ControlNetMediaPipeFace': {
     type: 'CrucibleAI/ControlNetMediaPipeFace',
@@ -18,12 +18,19 @@ import { forEach } from 'lodash-es';
 import { isAnySessionRejected } from 'services/api/thunks/session';
 import { appSocketInvocationError } from 'services/events/actions';
 
+export type ControlModes =
+  | 'balanced'
+  | 'more_prompt'
+  | 'more_control'
+  | 'unbalanced';
+
 export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
   isEnabled: true,
   model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
   weight: 1,
   beginStepPct: 0,
   endStepPct: 1,
+  controlMode: 'balanced',
   controlImage: null,
   processedControlImage: null,
   processorType: 'canny_image_processor',
@@ -39,6 +46,7 @@ export type ControlNetConfig = {
   weight: number;
   beginStepPct: number;
   endStepPct: number;
+  controlMode: ControlModes;
   controlImage: string | null;
   processedControlImage: string | null;
   processorType: ControlNetProcessorType;
@@ -181,6 +189,13 @@ export const controlNetSlice = createSlice({
       const { controlNetId, endStepPct } = action.payload;
       state.controlNets[controlNetId].endStepPct = endStepPct;
     },
+    controlNetControlModeChanged: (
+      state,
+      action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
+    ) => {
+      const { controlNetId, controlMode } = action.payload;
+      state.controlNets[controlNetId].controlMode = controlMode;
+    },
     controlNetProcessorParamsChanged: (
       state,
      action: PayloadAction<{
@@ -307,6 +322,7 @@ export const {
   controlNetWeightChanged,
   controlNetBeginStepPctChanged,
   controlNetEndStepPctChanged,
+  controlNetControlModeChanged,
   controlNetProcessorParamsChanged,
   controlNetProcessorTypeChanged,
   controlNetReset,
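A quick roundtrip through the new reducer, as a sketch: the slice state is trimmed to just the controlnet the action targets, the cast papers over the omitted slice fields, and 'abc' is an illustrative id.

import {
  controlNetSlice,
  controlNetControlModeChanged,
  initialControlNet,
} from 'features/controlNet/store/controlNetSlice';

const prev = {
  controlNets: { abc: { ...initialControlNet, controlNetId: 'abc' } },
} as ReturnType<typeof controlNetSlice.getInitialState>;

const next = controlNetSlice.reducer(
  prev,
  controlNetControlModeChanged({
    controlNetId: 'abc',
    controlMode: 'more_prompt',
  })
);
// next.controlNets.abc.controlMode === 'more_prompt'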
@@ -0,0 +1,45 @@
+import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAICollapse from 'common/components/IAICollapse';
+import { useCallback } from 'react';
+import { isEnabledToggled } from '../store/slice';
+import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
+import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
+import { Flex } from '@chakra-ui/react';
+
+const selector = createSelector(
+  stateSelector,
+  (state) => {
+    const { isEnabled } = state.dynamicPrompts;
+
+    return { isEnabled };
+  },
+  defaultSelectorOptions
+);
+
+const ParamDynamicPromptsCollapse = () => {
+  const dispatch = useAppDispatch();
+  const { isEnabled } = useAppSelector(selector);
+
+  const handleToggleIsEnabled = useCallback(() => {
+    dispatch(isEnabledToggled());
+  }, [dispatch]);
+
+  return (
+    <IAICollapse
+      isOpen={isEnabled}
+      onToggle={handleToggleIsEnabled}
+      label="Dynamic Prompts"
+      withSwitch
+    >
+      <Flex sx={{ gap: 2, flexDir: 'column' }}>
+        <ParamDynamicPromptsCombinatorial />
+        <ParamDynamicPromptsMaxPrompts />
+      </Flex>
+    </IAICollapse>
+  );
+};
+
+export default ParamDynamicPromptsCollapse;
@@ -0,0 +1,36 @@
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import { combinatorialToggled } from '../store/slice';
+import { createSelector } from '@reduxjs/toolkit';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { useCallback } from 'react';
+import { stateSelector } from 'app/store/store';
+import IAISwitch from 'common/components/IAISwitch';
+
+const selector = createSelector(
+  stateSelector,
+  (state) => {
+    const { combinatorial } = state.dynamicPrompts;
+
+    return { combinatorial };
+  },
+  defaultSelectorOptions
+);
+
+const ParamDynamicPromptsCombinatorial = () => {
+  const { combinatorial } = useAppSelector(selector);
+  const dispatch = useAppDispatch();
+
+  const handleChange = useCallback(() => {
+    dispatch(combinatorialToggled());
+  }, [dispatch]);
+
+  return (
+    <IAISwitch
+      label="Combinatorial Generation"
+      isChecked={combinatorial}
+      onChange={handleChange}
+    />
+  );
+};
+
+export default ParamDynamicPromptsCombinatorial;
@@ -0,0 +1,55 @@
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
+import IAISlider from 'common/components/IAISlider';
+import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
+import { createSelector } from '@reduxjs/toolkit';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import { useCallback } from 'react';
+import { stateSelector } from 'app/store/store';
+
+const selector = createSelector(
+  stateSelector,
+  (state) => {
+    const { maxPrompts, combinatorial } = state.dynamicPrompts;
+    const { min, sliderMax, inputMax } =
+      state.config.sd.dynamicPrompts.maxPrompts;
+
+    return { maxPrompts, min, sliderMax, inputMax, combinatorial };
+  },
+  defaultSelectorOptions
+);
+
+const ParamDynamicPromptsMaxPrompts = () => {
+  const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
+    useAppSelector(selector);
+  const dispatch = useAppDispatch();
+
+  const handleChange = useCallback(
+    (v: number) => {
+      dispatch(maxPromptsChanged(v));
+    },
+    [dispatch]
+  );
+
+  const handleReset = useCallback(() => {
+    dispatch(maxPromptsReset());
+  }, [dispatch]);
+
+  return (
+    <IAISlider
+      label="Max Prompts"
+      isDisabled={!combinatorial}
+      min={min}
+      max={sliderMax}
+      value={maxPrompts}
+      onChange={handleChange}
+      sliderNumberInputProps={{ max: inputMax }}
+      withSliderMarks
+      withInput
+      inputReadOnly
+      withReset
+      handleReset={handleReset}
+    />
+  );
+};
+
+export default ParamDynamicPromptsMaxPrompts;
@@ -0,0 +1 @@
+//
@@ -0,0 +1,50 @@
+import { PayloadAction, createSlice } from '@reduxjs/toolkit';
+import { RootState } from 'app/store/store';
+
+export interface DynamicPromptsState {
+  isEnabled: boolean;
+  maxPrompts: number;
+  combinatorial: boolean;
+}
+
+export const initialDynamicPromptsState: DynamicPromptsState = {
+  isEnabled: false,
+  maxPrompts: 100,
+  combinatorial: true,
+};
+
+const initialState: DynamicPromptsState = initialDynamicPromptsState;
+
+export const dynamicPromptsSlice = createSlice({
+  name: 'dynamicPrompts',
+  initialState,
+  reducers: {
+    maxPromptsChanged: (state, action: PayloadAction<number>) => {
+      state.maxPrompts = action.payload;
+    },
+    maxPromptsReset: (state) => {
+      state.maxPrompts = initialDynamicPromptsState.maxPrompts;
+    },
+    combinatorialToggled: (state) => {
+      state.combinatorial = !state.combinatorial;
+    },
+    isEnabledToggled: (state) => {
+      state.isEnabled = !state.isEnabled;
+    },
+  },
+  extraReducers: (builder) => {
+    //
+  },
+});
+
+export const {
+  isEnabledToggled,
+  maxPromptsChanged,
+  maxPromptsReset,
+  combinatorialToggled,
+} = dynamicPromptsSlice.actions;
+
+export default dynamicPromptsSlice.reducer;
+
+export const dynamicPromptsSelector = (state: RootState) =>
+  state.dynamicPrompts;
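The slice is small enough to sanity-check directly against its reducer (sketch):

import {
  dynamicPromptsSlice,
  initialDynamicPromptsState,
  isEnabledToggled,
  maxPromptsChanged,
  maxPromptsReset,
} from 'features/dynamicPrompts/store/slice';

let s = dynamicPromptsSlice.reducer(
  initialDynamicPromptsState,
  isEnabledToggled()
);
// s.isEnabled === true (toggled from the initial false)
s = dynamicPromptsSlice.reducer(s, maxPromptsChanged(25));
// s.maxPrompts === 25
s = dynamicPromptsSlice.reducer(s, maxPromptsReset());
// s.maxPrompts === 100 (back to initialDynamicPromptsState.maxPrompts)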
@@ -1,28 +1,41 @@
 import 'reactflow/dist/style.css';
-import { memo, useCallback } from 'react';
-import {
-  Tooltip,
-  Menu,
-  MenuButton,
-  MenuList,
-  MenuItem,
-} from '@chakra-ui/react';
-import { FaEllipsisV } from 'react-icons/fa';
+import { useCallback, forwardRef } from 'react';
+import { Flex, Text } from '@chakra-ui/react';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
-import { nodeAdded } from '../store/nodesSlice';
+import { nodeAdded, nodesSelector } from '../store/nodesSlice';
 import { map } from 'lodash-es';
-import { RootState } from 'app/store/store';
 import { useBuildInvocation } from '../hooks/useBuildInvocation';
 import { AnyInvocationType } from 'services/events/types';
-import IAIIconButton from 'common/components/IAIIconButton';
 import { useAppToaster } from 'app/components/Toaster';
+import { createSelector } from '@reduxjs/toolkit';
+import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
+import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
+
+type NodeTemplate = {
+  label: string;
+  value: string;
+  description: string;
+};
+
+const selector = createSelector(
+  nodesSelector,
+  (nodes) => {
+    const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
+      return {
+        label: template.title,
+        value: template.type,
+        description: template.description,
+      };
+    });
+
+    return { data };
+  },
+  defaultSelectorOptions
+);
 
 const AddNodeMenu = () => {
   const dispatch = useAppDispatch();
-  const invocationTemplates = useAppSelector(
-    (state: RootState) => state.nodes.invocationTemplates
-  );
+  const { data } = useAppSelector(selector);
 
   const buildInvocation = useBuildInvocation();
 
@@ -46,23 +59,52 @@ const AddNodeMenu = () => {
   );
 
   return (
-    <Menu isLazy>
-      <MenuButton
-        as={IAIIconButton}
-        aria-label="Add Node"
-        icon={<FaEllipsisV />}
+    <Flex sx={{ gap: 2, alignItems: 'center' }}>
+      <IAIMantineMultiSelect
+        selectOnBlur={false}
+        placeholder="Add Node"
+        value={[]}
+        data={data}
+        maxDropdownHeight={400}
+        nothingFound="No matching nodes"
+        itemComponent={SelectItem}
+        filter={(value, selected, item: NodeTemplate) =>
+          item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
+          item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
+          item.description.toLowerCase().includes(value.toLowerCase().trim())
+        }
+        onChange={(v) => {
+          v[0] && addNode(v[0] as AnyInvocationType);
+        }}
+        sx={{
+          width: '18rem',
+        }}
       />
-      <MenuList overflowY="scroll" height={400}>
-        {map(invocationTemplates, ({ title, description, type }, key) => {
-          return (
-            <Tooltip key={key} label={description} placement="end" hasArrow>
-              <MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
-            </Tooltip>
-          );
-        })}
-      </MenuList>
-    </Menu>
+    </Flex>
   );
 };
 
-export default memo(AddNodeMenu);
+interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
+  value: string;
+  label: string;
+  description: string;
+}
+
+const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
+  ({ label, description, ...others }: ItemProps, ref) => {
+    return (
+      <div ref={ref} {...others}>
+        <div>
+          <Text>{label}</Text>
+          <Text size="xs" color="base.600">
+            {description}
+          </Text>
+        </div>
+      </div>
+    );
+  }
+);
+
+SelectItem.displayName = 'SelectItem';
+
+export default AddNodeMenu;
@@ -23,7 +23,7 @@ const ModelInputFieldComponent = (
   const { t } = useTranslation();
 
   const { data: pipelineModels } = useListModelsQuery({
-    model_type: 'pipeline',
+    model_type: 'main',
   });
 
   const data = useMemo(() => {
@@ -1,10 +1,10 @@
 import { memo } from 'react';
 import { Panel } from 'reactflow';
-import NodeSearch from '../search/NodeSearch';
+import AddNodeMenu from '../AddNodeMenu';
 
 const TopLeftPanel = () => (
   <Panel position="top-left">
-    <NodeSearch />
+    <AddNodeMenu />
   </Panel>
 );
 
@@ -14,9 +14,6 @@ import {
 import { ImageField } from 'services/api/types';
 import { receivedOpenAPISchema } from 'services/api/thunks/schema';
 import { InvocationTemplate, InvocationValue } from '../types/types';
-import { parseSchema } from '../util/parseSchema';
-import { log } from 'app/logging/useLogger';
-import { size } from 'lodash-es';
 import { RgbaColor } from 'react-colorful';
 import { RootState } from 'app/store/store';
 
@@ -78,25 +75,17 @@ const nodesSlice = createSlice({
     shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
       state.shouldShowGraphOverlay = action.payload;
     },
-    parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => {
-      try {
-        const parsedSchema = parseSchema(action.payload);
-        // TODO: Achtung! Side effect in a reducer!
-        log.info(
-          { namespace: 'schema', nodes: parsedSchema },
-          `Parsed ${size(parsedSchema)} nodes`
-        );
-        state.invocationTemplates = parsedSchema;
-      } catch (err) {
-        console.error(err);
-      }
+    nodeTemplatesBuilt: (
+      state,
+      action: PayloadAction<Record<string, InvocationTemplate>>
+    ) => {
+      state.invocationTemplates = action.payload;
     },
     nodeEditorReset: () => {
       return { ...initialNodesState };
     },
   },
-  extraReducers(builder) {
+  extraReducers: (builder) => {
     builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
       state.schema = action.payload;
     });
@@ -112,10 +101,10 @@ export const {
   connectionStarted,
   connectionEnded,
   shouldShowGraphOverlayChanged,
-  parsedOpenAPISchema,
+  nodeTemplatesBuilt,
   nodeEditorReset,
 } = nodesSlice.actions;
 
 export default nodesSlice.reducer;
 
-export const nodesSelecter = (state: RootState) => state.nodes;
+export const nodesSelector = (state: RootState) => state.nodes;
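The net effect of this hunk: parsing and logging (side effects) move out of the reducer and into the schema listener added earlier, leaving the reducer a pure assignment. The split, sketched (the dispatch and schema payload are assumed to come from the listener effect):

import { parseSchema } from 'features/nodes/util/parseSchema';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import type { AppDispatch } from 'app/store/store';

declare const dispatch: AppDispatch; // assumed: provided by the listener effect
declare const schemaJSON: Parameters<typeof parseSchema>[0]; // the dereferenced schema

// impure work (parsing, logging) stays in the listener...
const nodeTemplates = parseSchema(schemaJSON);
// ...and the reducer receives a ready-made payload to store verbatim
dispatch(nodeTemplatesBuilt(nodeTemplates));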
@@ -34,12 +34,10 @@ export type InvocationTemplate = {
    * Array of invocation inputs
    */
   inputs: Record<string, InputFieldTemplate>;
-  // inputs: InputField[];
   /**
    * Array of the invocation outputs
    */
   outputs: Record<string, OutputFieldTemplate>;
-  // outputs: OutputField[];
 };
 
 export type FieldUIConfig = {
@@ -335,7 +333,7 @@ export type TypeHints = {
 };
 
 export type InvocationSchemaExtra = {
-  output: OpenAPIV3.ReferenceObject; // the output of the invocation
+  output: OpenAPIV3.SchemaObject; // the output of the invocation
   ui?: {
     tags?: string[];
     type_hints?: TypeHints;
@@ -1,5 +1,5 @@
 import { RootState } from 'app/store/store';
-import { filter, forEach, size } from 'lodash-es';
+import { filter } from 'lodash-es';
 import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
 import { NonNullableGraph } from '../types/types';
 import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
       (c.processorType === 'none' && Boolean(c.controlImage)))
   );
 
-  // Add ControlNet
-  if (isControlNetEnabled && validControlNets.length > 0) {
-    if (size(controlNets) > 1) {
+  if (isControlNetEnabled && Boolean(validControlNets.length)) {
+    if (validControlNets.length > 1) {
+      // We have multiple controlnets, add ControlNet collector
       const controlNetIterateNode: CollectInvocation = {
         id: CONTROL_NET_COLLECT,
         type: 'collect',
@@ -36,29 +36,25 @@ export const addControlNetToLinearGraph = (
       });
     }
 
-    forEach(controlNets, (controlNet) => {
+    validControlNets.forEach((controlNet) => {
       const {
         controlNetId,
-        isEnabled,
         controlImage,
         processedControlImage,
         beginStepPct,
         endStepPct,
+        controlMode,
         model,
         processorType,
         weight,
       } = controlNet;
 
-      if (!isEnabled) {
-        // Skip disabled ControlNets
-        return;
-      }
-
       const controlNetNode: ControlNetInvocation = {
         id: `control_net_${controlNetId}`,
         type: 'controlnet',
         begin_step_percent: beginStepPct,
         end_step_percent: endStepPct,
+        control_mode: controlMode,
         control_model: model as ControlNetInvocation['control_model'],
         control_weight: weight,
       };
@@ -80,7 +76,8 @@ export const addControlNetToLinearGraph = (
 
       graph.nodes[controlNetNode.id] = controlNetNode;
 
-      if (size(controlNets) > 1) {
+      if (validControlNets.length > 1) {
+        // if we have multiple controlnets, link to the collector
         graph.edges.push({
           source: { node_id: controlNetNode.id, field: 'control' },
           destination: {
@@ -89,6 +86,7 @@ export const addControlNetToLinearGraph = (
           },
         });
       } else {
+        // otherwise, link directly to the base node
         graph.edges.push({
           source: { node_id: controlNetNode.id, field: 'control' },
           destination: {
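These hunks lean on `validControlNets`, whose filter predicate is only partially visible in the context lines above. Reconstructed as a sketch from those lines (the exact upstream code may differ slightly; `controlNets` is assumed to be the slice's id-to-config map):

import { filter } from 'lodash-es';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';

declare const controlNets: Record<string, ControlNetConfig>; // assumed input

// keep only enabled controlnets that actually have an image to contribute
const validControlNets = filter(
  controlNets,
  (c) =>
    c.isEnabled &&
    (Boolean(c.processedControlImage) ||
      (c.processorType === 'none' && Boolean(c.controlImage)))
);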
@@ -349,21 +349,11 @@ export const getFieldType = (
 
   if (typeHints && name in typeHints) {
     rawFieldType = typeHints[name];
-  } else if (!schemaObject.type) {
-    // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf
-    if (schemaObject.allOf) {
-      rawFieldType = refObjectToFieldType(
-        schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
-      );
-    } else if (schemaObject.anyOf) {
-      rawFieldType = refObjectToFieldType(
-        schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
-      );
-    } else if (schemaObject.oneOf) {
-      rawFieldType = refObjectToFieldType(
-        schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
-      );
-    }
+  } else if (!schemaObject.type && schemaObject.allOf) {
+    // if schemaObject has no type, then it should have one of allOf
+    rawFieldType =
+      (schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
+      'Missing Field Type';
   } else if (schemaObject.enum) {
     rawFieldType = 'enum';
   } else if (schemaObject.type) {
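Concretely, the new branch reads the field type from the first `allOf` member's title instead of resolving a reference object. For an input field declared as `{ allOf: [{ title: 'ImageField', ... }] }`, a sketch:

import { OpenAPIV3 } from 'openapi-types';

const schemaObject = {
  allOf: [{ title: 'ImageField' }],
} as OpenAPIV3.SchemaObject;

// mirror of the new branch above
const rawFieldType =
  (schemaObject.allOf![0] as OpenAPIV3.SchemaObject).title ??
  'Missing Field Type';
// rawFieldType === 'ImageField'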
@@ -0,0 +1,153 @@
+import { RootState } from 'app/store/store';
+import { NonNullableGraph } from 'features/nodes/types/types';
+import {
+  DynamicPromptInvocation,
+  IterateInvocation,
+  NoiseInvocation,
+  RandomIntInvocation,
+  RangeOfSizeInvocation,
+} from 'services/api/types';
+import {
+  DYNAMIC_PROMPT,
+  ITERATE,
+  NOISE,
+  POSITIVE_CONDITIONING,
+  RANDOM_INT,
+  RANGE_OF_SIZE,
+} from './constants';
+import { unset } from 'lodash-es';
+
+export const addDynamicPromptsToGraph = (
+  graph: NonNullableGraph,
+  state: RootState
+): void => {
+  const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
+    state.generation;
+
+  const {
+    combinatorial,
+    isEnabled: isDynamicPromptsEnabled,
+    maxPrompts,
+  } = state.dynamicPrompts;
+
+  if (isDynamicPromptsEnabled) {
+    // iteration is handled via dynamic prompts
+    unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
+
+    const dynamicPromptNode: DynamicPromptInvocation = {
+      id: DYNAMIC_PROMPT,
+      type: 'dynamic_prompt',
+      max_prompts: combinatorial ? maxPrompts : iterations,
+      combinatorial,
+      prompt: positivePrompt,
+    };
+
+    const iterateNode: IterateInvocation = {
+      id: ITERATE,
+      type: 'iterate',
+    };
+
+    graph.nodes[DYNAMIC_PROMPT] = dynamicPromptNode;
+    graph.nodes[ITERATE] = iterateNode;
+
+    // connect dynamic prompts to compel nodes
+    graph.edges.push(
+      {
+        source: {
+          node_id: DYNAMIC_PROMPT,
+          field: 'prompt_collection',
+        },
+        destination: {
+          node_id: ITERATE,
+          field: 'collection',
+        },
+      },
+      {
+        source: {
+          node_id: ITERATE,
+          field: 'item',
+        },
+        destination: {
+          node_id: POSITIVE_CONDITIONING,
+          field: 'prompt',
+        },
+      }
+    );
+
+    if (shouldRandomizeSeed) {
+      // Random int node to generate the starting seed
+      const randomIntNode: RandomIntInvocation = {
+        id: RANDOM_INT,
+        type: 'rand_int',
+      };
+
+      graph.nodes[RANDOM_INT] = randomIntNode;
+
+      // Connect random int to the start of the range of size so the range starts on the random first seed
+      graph.edges.push({
+        source: { node_id: RANDOM_INT, field: 'a' },
+        destination: { node_id: NOISE, field: 'seed' },
+      });
+    } else {
+      // User specified seed, so set the start of the range of size to the seed
+      (graph.nodes[NOISE] as NoiseInvocation).seed = seed;
+    }
+  } else {
+    const rangeOfSizeNode: RangeOfSizeInvocation = {
+      id: RANGE_OF_SIZE,
+      type: 'range_of_size',
+      size: iterations,
+      step: 1,
+    };
+
+    const iterateNode: IterateInvocation = {
+      id: ITERATE,
+      type: 'iterate',
+    };
+
+    graph.nodes[ITERATE] = iterateNode;
+    graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
+
+    graph.edges.push({
+      source: {
+        node_id: RANGE_OF_SIZE,
+        field: 'collection',
+      },
+      destination: {
+        node_id: ITERATE,
+        field: 'collection',
+      },
+    });
+
+    graph.edges.push({
+      source: {
+        node_id: ITERATE,
+        field: 'item',
+      },
+      destination: {
+        node_id: NOISE,
+        field: 'seed',
+      },
+    });
+
+    // handle seed
+    if (shouldRandomizeSeed) {
+      // Random int node to generate the starting seed
+      const randomIntNode: RandomIntInvocation = {
+        id: RANDOM_INT,
+        type: 'rand_int',
+      };
+
+      graph.nodes[RANDOM_INT] = randomIntNode;
+
+      // Connect random int to the start of the range of size so the range starts on the random first seed
+      graph.edges.push({
+        source: { node_id: RANDOM_INT, field: 'a' },
+        destination: { node_id: RANGE_OF_SIZE, field: 'start' },
+      });
+    } else {
+      // User specified seed, so set the start of the range of size to the seed
+      rangeOfSizeNode.start = seed;
+    }
+  }
+};
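How the graph builders below consume this helper, sketched: they assemble the base graph without any RANGE_OF_SIZE / ITERATE plumbing of their own, then hand the graph over to be mutated (`state` is assumed to be the app's RootState at build time):

import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';

declare const state: RootState; // assumed: the app state at build time

const graph = {
  nodes: {
    // POSITIVE_CONDITIONING, NOISE, etc., as built by the caller
  },
  edges: [],
} as unknown as NonNullableGraph;

// wires DYNAMIC_PROMPT -> ITERATE -> POSITIVE_CONDITIONING when dynamic
// prompts are enabled, or RANGE_OF_SIZE -> ITERATE -> NOISE otherwise,
// plus the RANDOM_INT seed node when shouldRandomizeSeed is set
addDynamicPromptsToGraph(graph, state);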
@@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
 import {
   ImageDTO,
   ImageResizeInvocation,
+  ImageToLatentsInvocation,
   RandomIntInvocation,
   RangeOfSizeInvocation,
 } from 'services/api/types';
@@ -10,7 +11,7 @@ import { log } from 'app/logging/useLogger';
 import {
   ITERATE,
   LATENTS_TO_IMAGE,
-  MODEL_LOADER,
+  PIPELINE_MODEL_LOADER,
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
@@ -24,6 +25,7 @@ import {
 import { set } from 'lodash-es';
 import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
 import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
 
 const moduleLog = log.child({ namespace: 'nodes' });
 
@@ -75,31 +77,19 @@ export const buildCanvasImageToImageGraph = (
       id: NEGATIVE_CONDITIONING,
       prompt: negativePrompt,
     },
-    [RANGE_OF_SIZE]: {
-      type: 'range_of_size',
-      id: RANGE_OF_SIZE,
-      // seed - must be connected manually
-      // start: 0,
-      size: iterations,
-      step: 1,
-    },
     [NOISE]: {
       type: 'noise',
       id: NOISE,
     },
-    [MODEL_LOADER]: {
+    [PIPELINE_MODEL_LOADER]: {
       type: 'pipeline_model_loader',
-      id: MODEL_LOADER,
+      id: PIPELINE_MODEL_LOADER,
       model,
     },
    [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
     },
-    [ITERATE]: {
-      type: 'iterate',
-      id: ITERATE,
-    },
     [LATENTS_TO_LATENTS]: {
       type: 'l2l',
       id: LATENTS_TO_LATENTS,
@@ -120,7 +110,7 @@ export const buildCanvasImageToImageGraph = (
   edges: [
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -130,7 +120,7 @@ export const buildCanvasImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -140,7 +130,7 @@ export const buildCanvasImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -148,26 +138,6 @@ export const buildCanvasImageToImageGraph = (
         field: 'vae',
       },
     },
-    {
-      source: {
-        node_id: RANGE_OF_SIZE,
-        field: 'collection',
-      },
-      destination: {
-        node_id: ITERATE,
-        field: 'collection',
-      },
-    },
-    {
-      source: {
-        node_id: ITERATE,
-        field: 'item',
-      },
-      destination: {
-        node_id: NOISE,
-        field: 'seed',
-      },
-    },
     {
       source: {
         node_id: LATENTS_TO_LATENTS,
@@ -200,7 +170,7 @@ export const buildCanvasImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -210,7 +180,7 @@ export const buildCanvasImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'unet',
       },
       destination: {
@@ -241,26 +211,6 @@ export const buildCanvasImageToImageGraph = (
     ],
   };
 
-  // handle seed
-  if (shouldRandomizeSeed) {
-    // Random int node to generate the starting seed
-    const randomIntNode: RandomIntInvocation = {
-      id: RANDOM_INT,
-      type: 'rand_int',
-    };
-
-    graph.nodes[RANDOM_INT] = randomIntNode;
-
-    // Connect random int to the start of the range of size so the range starts on the random first seed
-    graph.edges.push({
-      source: { node_id: RANDOM_INT, field: 'a' },
-      destination: { node_id: RANGE_OF_SIZE, field: 'start' },
-    });
-  } else {
-    // User specified seed, so set the start of the range of size to the seed
-    (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
-  }
-
   // handle `fit`
   if (initialImage.width !== width || initialImage.height !== height) {
     // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
@@ -306,9 +256,9 @@ export const buildCanvasImageToImageGraph = (
     });
   } else {
     // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
-    set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
+    (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
       image_name: initialImage.image_name,
-    });
+    };
 
     // Pass the image's dimensions to the `NOISE` node
     graph.edges.push({
@@ -327,7 +277,10 @@ export const buildCanvasImageToImageGraph = (
     });
   }
 
-  // add controlnet
+  // add dynamic prompts, mutating `graph`
+  addDynamicPromptsToGraph(graph, state);
+
+  // add controlnet, mutating `graph`
   addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
 
   return graph;
@@ -9,7 +9,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
 import { log } from 'app/logging/useLogger';
 import {
   ITERATE,
-  MODEL_LOADER,
+  PIPELINE_MODEL_LOADER,
   NEGATIVE_CONDITIONING,
   POSITIVE_CONDITIONING,
   RANDOM_INT,
@@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
       id: NEGATIVE_CONDITIONING,
       prompt: negativePrompt,
     },
-    [MODEL_LOADER]: {
+    [PIPELINE_MODEL_LOADER]: {
       type: 'pipeline_model_loader',
-      id: MODEL_LOADER,
+      id: PIPELINE_MODEL_LOADER,
       model,
     },
     [RANGE_OF_SIZE]: {
@@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'unet',
       },
       destination: {
@@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -4,7 +4,7 @@ import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
 import {
   ITERATE,
   LATENTS_TO_IMAGE,
-  MODEL_LOADER,
+  PIPELINE_MODEL_LOADER,
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
@@ -15,6 +15,7 @@ import {
 } from './constants';
 import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
 import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
 
 /**
  * Builds the Canvas tab's Text to Image graph.
@@ -62,13 +63,6 @@ export const buildCanvasTextToImageGraph = (
       id: NEGATIVE_CONDITIONING,
       prompt: negativePrompt,
     },
-    [RANGE_OF_SIZE]: {
-      type: 'range_of_size',
-      id: RANGE_OF_SIZE,
-      // start: 0, // seed - must be connected manually
-      size: iterations,
-      step: 1,
-    },
     [NOISE]: {
       type: 'noise',
       id: NOISE,
@@ -82,19 +76,15 @@ export const buildCanvasTextToImageGraph = (
       scheduler,
       steps,
     },
-    [MODEL_LOADER]: {
+    [PIPELINE_MODEL_LOADER]: {
       type: 'pipeline_model_loader',
-      id: MODEL_LOADER,
+      id: PIPELINE_MODEL_LOADER,
       model,
     },
     [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
     },
-    [ITERATE]: {
-      type: 'iterate',
-      id: ITERATE,
-    },
   },
   edges: [
     {
@@ -119,7 +109,7 @@ export const buildCanvasTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -129,7 +119,7 @@ export const buildCanvasTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -139,7 +129,7 @@ export const buildCanvasTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'unet',
       },
       destination: {
@@ -159,7 +149,7 @@ export const buildCanvasTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -167,26 +157,6 @@ export const buildCanvasTextToImageGraph = (
         field: 'vae',
       },
     },
-    {
-      source: {
-        node_id: RANGE_OF_SIZE,
-        field: 'collection',
-      },
-      destination: {
-        node_id: ITERATE,
-        field: 'collection',
-      },
-    },
-    {
source: {
|
|
||||||
node_id: ITERATE,
|
|
||||||
field: 'item',
|
|
||||||
},
|
|
||||||
destination: {
|
|
||||||
node_id: NOISE,
|
|
||||||
field: 'seed',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
source: {
|
source: {
|
||||||
node_id: NOISE,
|
node_id: NOISE,
|
||||||
@ -200,27 +170,10 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
// handle seed
|
// add dynamic prompts, mutating `graph`
|
||||||
if (shouldRandomizeSeed) {
|
addDynamicPromptsToGraph(graph, state);
|
||||||
// Random int node to generate the starting seed
|
|
||||||
const randomIntNode: RandomIntInvocation = {
|
|
||||||
id: RANDOM_INT,
|
|
||||||
type: 'rand_int',
|
|
||||||
};
|
|
||||||
|
|
||||||
graph.nodes[RANDOM_INT] = randomIntNode;
|
// add controlnet, mutating `graph`
|
||||||
|
|
||||||
// Connect random int to the start of the range of size so the range starts on the random first seed
|
|
||||||
graph.edges.push({
|
|
||||||
source: { node_id: RANDOM_INT, field: 'a' },
|
|
||||||
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
// User specified seed, so set the start of the range of size to the seed
|
|
||||||
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add controlnet
|
|
||||||
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
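The seed plumbing removed here batched generations by expanding one starting seed into `iterations` consecutive seeds (`RANGE_OF_SIZE` with `step: 1`, iterated into `NOISE.seed`). The arithmetic the deleted nodes performed, as a plain sketch:

```ts
// What RANGE_OF_SIZE(start, size, step) fed through ITERATE produced:
const rangeOfSize = (start: number, size: number, step: number): number[] =>
  Array.from({ length: size }, (_, i) => start + i * step);

// e.g. seed 1234, 3 iterations: NOISE.seed received 1234, 1235, 1236
console.log(rangeOfSize(1234, 3, 1)); // [1234, 1235, 1236]
```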
@@ -1,28 +1,24 @@
 import { RootState } from 'app/store/store';
 import {
   ImageResizeInvocation,
-  RandomIntInvocation,
-  RangeOfSizeInvocation,
+  ImageToLatentsInvocation,
 } from 'services/api/types';
 import { NonNullableGraph } from 'features/nodes/types/types';
 import { log } from 'app/logging/useLogger';
 import {
-  ITERATE,
   LATENTS_TO_IMAGE,
-  MODEL_LOADER,
+  PIPELINE_MODEL_LOADER,
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
-  RANDOM_INT,
-  RANGE_OF_SIZE,
   IMAGE_TO_IMAGE_GRAPH,
   IMAGE_TO_LATENTS,
   LATENTS_TO_LATENTS,
   RESIZE,
 } from './constants';
-import { set } from 'lodash-es';
 import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
 import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
 
 const moduleLog = log.child({ namespace: 'nodes' });
 
@@ -44,9 +40,6 @@ export const buildLinearImageToImageGraph = (
     shouldFitToWidthHeight,
     width,
     height,
-    iterations,
-    seed,
-    shouldRandomizeSeed,
   } = state.generation;
 
 /**
@@ -79,31 +72,19 @@ export const buildLinearImageToImageGraph = (
       id: NEGATIVE_CONDITIONING,
       prompt: negativePrompt,
     },
-    [RANGE_OF_SIZE]: {
-      type: 'range_of_size',
-      id: RANGE_OF_SIZE,
-      // seed - must be connected manually
-      // start: 0,
-      size: iterations,
-      step: 1,
-    },
     [NOISE]: {
       type: 'noise',
       id: NOISE,
     },
-    [MODEL_LOADER]: {
+    [PIPELINE_MODEL_LOADER]: {
       type: 'pipeline_model_loader',
-      id: MODEL_LOADER,
+      id: PIPELINE_MODEL_LOADER,
       model,
     },
    [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
     },
-    [ITERATE]: {
-      type: 'iterate',
-      id: ITERATE,
-    },
     [LATENTS_TO_LATENTS]: {
       type: 'l2l',
       id: LATENTS_TO_LATENTS,
@@ -124,7 +105,7 @@ export const buildLinearImageToImageGraph = (
   edges: [
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -134,7 +115,7 @@ export const buildLinearImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -144,7 +125,7 @@ export const buildLinearImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -152,26 +133,6 @@ export const buildLinearImageToImageGraph = (
         field: 'vae',
       },
     },
-    {
-      source: {
-        node_id: RANGE_OF_SIZE,
-        field: 'collection',
-      },
-      destination: {
-        node_id: ITERATE,
-        field: 'collection',
-      },
-    },
-    {
-      source: {
-        node_id: ITERATE,
-        field: 'item',
-      },
-      destination: {
-        node_id: NOISE,
-        field: 'seed',
-      },
-    },
     {
       source: {
         node_id: LATENTS_TO_LATENTS,
@@ -204,7 +165,7 @@ export const buildLinearImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -214,7 +175,7 @@ export const buildLinearImageToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'unet',
       },
       destination: {
@@ -245,26 +206,6 @@ export const buildLinearImageToImageGraph = (
     ],
   };
 
-  // handle seed
-  if (shouldRandomizeSeed) {
-    // Random int node to generate the starting seed
-    const randomIntNode: RandomIntInvocation = {
-      id: RANDOM_INT,
-      type: 'rand_int',
-    };
-
-    graph.nodes[RANDOM_INT] = randomIntNode;
-
-    // Connect random int to the start of the range of size so the range starts on the random first seed
-    graph.edges.push({
-      source: { node_id: RANDOM_INT, field: 'a' },
-      destination: { node_id: RANGE_OF_SIZE, field: 'start' },
-    });
-  } else {
-    // User specified seed, so set the start of the range of size to the seed
-    (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
-  }
-
   // handle `fit`
   if (
     shouldFitToWidthHeight &&
@@ -313,9 +254,9 @@ export const buildLinearImageToImageGraph = (
     });
   } else {
     // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
-    set(graph.nodes[IMAGE_TO_LATENTS], 'image', {
+    (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
       image_name: initialImage.imageName,
-    });
+    };
 
     // Pass the image's dimensions to the `NOISE` node
     graph.edges.push({
@@ -334,7 +275,10 @@ export const buildLinearImageToImageGraph = (
     });
   }
 
-  // add controlnet
+  // add dynamic prompts, mutating `graph`
+  addDynamicPromptsToGraph(graph, state);
+
+  // add controlnet, mutating `graph`
   addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
 
   return graph;
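Swapping lodash's `set` for a typed cast is a small but real type-safety win: `set` writes through an untyped string path, while the cast keeps the assignment checked against the invocation type. Side by side (the image name is a placeholder):

```ts
import { set } from 'lodash-es';
import { ImageToLatentsInvocation } from 'services/api/types';
import { IMAGE_TO_LATENTS } from './constants';

// Before: neither the string path nor the payload is type-checked.
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { image_name: 'example.png' });

// After: `image` and `image_name` are checked against the invocation type.
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
  image_name: 'example.png',
};
```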
@@ -1,33 +1,20 @@
 import { RootState } from 'app/store/store';
 import { NonNullableGraph } from 'features/nodes/types/types';
 import {
-  BaseModelType,
-  RandomIntInvocation,
-  RangeOfSizeInvocation,
-} from 'services/api/types';
-import {
-  ITERATE,
   LATENTS_TO_IMAGE,
-  MODEL_LOADER,
+  PIPELINE_MODEL_LOADER,
   NEGATIVE_CONDITIONING,
   NOISE,
   POSITIVE_CONDITIONING,
-  RANDOM_INT,
-  RANGE_OF_SIZE,
   TEXT_TO_IMAGE_GRAPH,
   TEXT_TO_LATENTS,
 } from './constants';
 import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
 import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
+import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
-
-type TextToImageGraphOverrides = {
-  width: number;
-  height: number;
-};
 
 export const buildLinearTextToImageGraph = (
-  state: RootState,
-  overrides?: TextToImageGraphOverrides
+  state: RootState
 ): NonNullableGraph => {
   const {
     positivePrompt,
@@ -38,9 +25,6 @@ export const buildLinearTextToImageGraph = (
     steps,
     width,
     height,
-    iterations,
-    seed,
-    shouldRandomizeSeed,
   } = state.generation;
 
   const model = modelIdToPipelineModelField(modelId);
@@ -68,18 +52,11 @@ export const buildLinearTextToImageGraph = (
       id: NEGATIVE_CONDITIONING,
       prompt: negativePrompt,
     },
-    [RANGE_OF_SIZE]: {
-      type: 'range_of_size',
-      id: RANGE_OF_SIZE,
-      // start: 0, // seed - must be connected manually
-      size: iterations,
-      step: 1,
-    },
     [NOISE]: {
       type: 'noise',
       id: NOISE,
-      width: overrides?.width || width,
-      height: overrides?.height || height,
+      width,
+      height,
     },
     [TEXT_TO_LATENTS]: {
       type: 't2l',
@@ -88,19 +65,15 @@ export const buildLinearTextToImageGraph = (
       scheduler,
       steps,
     },
-    [MODEL_LOADER]: {
+    [PIPELINE_MODEL_LOADER]: {
       type: 'pipeline_model_loader',
-      id: MODEL_LOADER,
+      id: PIPELINE_MODEL_LOADER,
       model,
     },
     [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
     },
-    [ITERATE]: {
-      type: 'iterate',
-      id: ITERATE,
-    },
   },
   edges: [
     {
@@ -125,7 +98,7 @@ export const buildLinearTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -135,7 +108,7 @@ export const buildLinearTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'clip',
       },
       destination: {
@@ -145,7 +118,7 @@ export const buildLinearTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'unet',
       },
       destination: {
@@ -165,7 +138,7 @@ export const buildLinearTextToImageGraph = (
     },
     {
       source: {
-        node_id: MODEL_LOADER,
+        node_id: PIPELINE_MODEL_LOADER,
         field: 'vae',
       },
       destination: {
@@ -173,26 +146,6 @@ export const buildLinearTextToImageGraph = (
         field: 'vae',
       },
     },
-    {
-      source: {
-        node_id: RANGE_OF_SIZE,
-        field: 'collection',
-      },
-      destination: {
-        node_id: ITERATE,
-        field: 'collection',
-      },
-    },
-    {
-      source: {
-        node_id: ITERATE,
-        field: 'item',
-      },
-      destination: {
-        node_id: NOISE,
-        field: 'seed',
-      },
-    },
     {
       source: {
         node_id: NOISE,
@@ -206,27 +159,10 @@ export const buildLinearTextToImageGraph = (
     ],
   };
 
-  // handle seed
-  if (shouldRandomizeSeed) {
-    // Random int node to generate the starting seed
-    const randomIntNode: RandomIntInvocation = {
-      id: RANDOM_INT,
-      type: 'rand_int',
-    };
-
-    graph.nodes[RANDOM_INT] = randomIntNode;
-
-    // Connect random int to the start of the range of size so the range starts on the random first seed
-    graph.edges.push({
-      source: { node_id: RANDOM_INT, field: 'a' },
-      destination: { node_id: RANGE_OF_SIZE, field: 'start' },
-    });
-  } else {
-    // User specified seed, so set the start of the range of size to the seed
-    (graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
-  }
-
-  // add controlnet
+  // add dynamic prompts, mutating `graph`
+  addDynamicPromptsToGraph(graph, state);
+
+  // add controlnet, mutating `graph`
   addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
 
   return graph;
@@ -7,12 +7,13 @@ export const NOISE = 'noise';
 export const RANDOM_INT = 'rand_int';
 export const RANGE_OF_SIZE = 'range_of_size';
 export const ITERATE = 'iterate';
-export const MODEL_LOADER = 'pipeline_model_loader';
+export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
 export const IMAGE_TO_LATENTS = 'image_to_latents';
 export const LATENTS_TO_LATENTS = 'latents_to_latents';
 export const RESIZE = 'resize_image';
 export const INPAINT = 'inpaint';
 export const CONTROL_NET_COLLECT = 'control_net_collect';
+export const DYNAMIC_PROMPT = 'dynamic_prompt';
 
 // friendly graph ids
 export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
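These constants are node ids, and the ids double as edge endpoints, so renaming `MODEL_LOADER` to `PIPELINE_MODEL_LOADER` while keeping its value unchanged is a source-only rename that leaves serialized graphs compatible. Typical use:

```ts
import { PIPELINE_MODEL_LOADER, TEXT_TO_LATENTS } from './constants';

// The constant's value ('pipeline_model_loader') is what lands in the graph;
// only the TypeScript identifier changed.
const unetEdge = {
  source: { node_id: PIPELINE_MODEL_LOADER, field: 'unet' },
  destination: { node_id: TEXT_TO_LATENTS, field: 'unet' },
};
```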
@@ -1,26 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-import { RootState } from 'app/store/store';
-import { CompelInvocation } from 'services/api/types';
-import { O } from 'ts-toolbelt';
-
-export const buildCompelNode = (
-  prompt: string,
-  state: RootState,
-  overrides: O.Partial<CompelInvocation, 'deep'> = {}
-): CompelInvocation => {
-  const nodeId = uuidv4();
-  const { generation } = state;
-
-  const { model } = generation;
-
-  const compelNode: CompelInvocation = {
-    id: nodeId,
-    type: 'compel',
-    prompt,
-    model,
-  };
-
-  Object.assign(compelNode, overrides);
-
-  return compelNode;
-};
@@ -1,107 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-import { RootState } from 'app/store/store';
-import {
-  Edge,
-  ImageToImageInvocation,
-  TextToImageInvocation,
-} from 'services/api/types';
-import { O } from 'ts-toolbelt';
-import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
-
-export const buildImg2ImgNode = (
-  state: RootState,
-  overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
-): ImageToImageInvocation => {
-  const nodeId = uuidv4();
-  const { generation } = state;
-
-  const activeTabName = activeTabNameSelector(state);
-
-  const {
-    positivePrompt: prompt,
-    negativePrompt: negativePrompt,
-    seed,
-    steps,
-    width,
-    height,
-    cfgScale,
-    scheduler,
-    model,
-    img2imgStrength: strength,
-    shouldFitToWidthHeight: fit,
-    shouldRandomizeSeed,
-    initialImage,
-  } = generation;
-
-  // const initialImage = initialImageSelector(state);
-
-  const imageToImageNode: ImageToImageInvocation = {
-    id: nodeId,
-    type: 'img2img',
-    prompt: `${prompt} [${negativePrompt}]`,
-    steps,
-    width,
-    height,
-    cfg_scale: cfgScale,
-    scheduler,
-    model,
-    strength,
-    fit,
-  };
-
-  // on Canvas tab, we do not manually specific init image
-  if (activeTabName !== 'unifiedCanvas') {
-    if (!initialImage) {
-      // TODO: handle this more better
-      throw 'no initial image';
-    }
-
-    imageToImageNode.image = {
-      image_name: initialImage.imageName,
-    };
-  }
-
-  if (!shouldRandomizeSeed) {
-    imageToImageNode.seed = seed;
-  }
-
-  Object.assign(imageToImageNode, overrides);
-
-  return imageToImageNode;
-};
-
-type hiresReturnType = {
-  node: Record<string, ImageToImageInvocation>;
-  edge: Edge;
-};
-
-export const buildHiResNode = (
-  baseNode: Record<string, TextToImageInvocation>,
-  strength?: number
-): hiresReturnType => {
-  const nodeId = uuidv4();
-  const baseNodeId = Object.keys(baseNode)[0];
-  const baseNodeValues = Object.values(baseNode)[0];
-
-  return {
-    node: {
-      [nodeId]: {
-        ...baseNodeValues,
-        id: nodeId,
-        type: 'img2img',
-        strength,
-        fit: true,
-      },
-    },
-    edge: {
-      source: {
-        field: 'image',
-        node_id: baseNodeId,
-      },
-      destination: {
-        field: 'image',
-        node_id: nodeId,
-      },
-    },
-  };
-};
@@ -1,48 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-import { RootState } from 'app/store/store';
-import { InpaintInvocation } from 'services/api/types';
-import { O } from 'ts-toolbelt';
-
-export const buildInpaintNode = (
-  state: RootState,
-  overrides: O.Partial<InpaintInvocation, 'deep'> = {}
-): InpaintInvocation => {
-  const nodeId = uuidv4();
-
-  const {
-    positivePrompt: prompt,
-    negativePrompt: negativePrompt,
-    seed,
-    steps,
-    width,
-    height,
-    cfgScale,
-    scheduler,
-    model,
-    img2imgStrength: strength,
-    shouldFitToWidthHeight: fit,
-    shouldRandomizeSeed,
-  } = state.generation;
-
-  const inpaintNode: InpaintInvocation = {
-    id: nodeId,
-    type: 'inpaint',
-    prompt: `${prompt} [${negativePrompt}]`,
-    steps,
-    width,
-    height,
-    cfg_scale: cfgScale,
-    scheduler,
-    model,
-    strength,
-    fit,
-  };
-
-  if (!shouldRandomizeSeed) {
-    inpaintNode.seed = seed;
-  }
-
-  Object.assign(inpaintNode, overrides);
-
-  return inpaintNode;
-};
@@ -1,13 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-
-import { IterateInvocation } from 'services/api/types';
-
-export const buildIterateNode = (): IterateInvocation => {
-  const nodeId = uuidv4();
-  return {
-    id: nodeId,
-    type: 'iterate',
-    // collection: [],
-    // index: 0,
-  };
-};
@@ -1,26 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-
-import { RootState } from 'app/store/store';
-import { RandomRangeInvocation, RangeInvocation } from 'services/api/types';
-
-export const buildRangeNode = (
-  state: RootState
-): RangeInvocation | RandomRangeInvocation => {
-  const nodeId = uuidv4();
-  const { shouldRandomizeSeed, iterations, seed } = state.generation;
-
-  if (shouldRandomizeSeed) {
-    return {
-      id: nodeId,
-      type: 'random_range',
-      size: iterations,
-    };
-  }
-
-  return {
-    id: nodeId,
-    type: 'range',
-    start: seed,
-    stop: seed + iterations,
-  };
-};
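As the deleted helper's arithmetic implies, `range` is half-open: with `start: seed` and `stop: seed + iterations` it yields exactly `iterations` seeds, while `random_range` draws `size` random values instead. The fixed-seed case reduces to:

```ts
// Half-open [start, stop): produces stop - start values.
const range = (start: number, stop: number): number[] =>
  Array.from({ length: stop - start }, (_, i) => start + i);

console.log(range(42, 45)); // [42, 43, 44], three images from seed 42
```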
@@ -1,45 +0,0 @@
-import { v4 as uuidv4 } from 'uuid';
-import { RootState } from 'app/store/store';
-import { TextToImageInvocation } from 'services/api/types';
-import { O } from 'ts-toolbelt';
-
-export const buildTxt2ImgNode = (
-  state: RootState,
-  overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
-): TextToImageInvocation => {
-  const nodeId = uuidv4();
-  const { generation } = state;
-
-  const {
-    positivePrompt: prompt,
-    negativePrompt: negativePrompt,
-    seed,
-    steps,
-    width,
-    height,
-    cfgScale: cfg_scale,
-    scheduler,
-    shouldRandomizeSeed,
-    model,
-  } = generation;
-
-  const textToImageNode: NonNullable<TextToImageInvocation> = {
-    id: nodeId,
-    type: 'txt2img',
-    prompt: `${prompt} [${negativePrompt}]`,
-    steps,
-    width,
-    height,
-    cfg_scale,
-    scheduler,
-    model,
-  };
-
-  if (!shouldRandomizeSeed) {
-    textToImageNode.seed = seed;
-  }
-
-  Object.assign(textToImageNode, overrides);
-
-  return textToImageNode;
-};
@@ -5,127 +5,154 @@ import {
   InputFieldTemplate,
   InvocationSchemaObject,
   InvocationTemplate,
-  isInvocationSchemaObject,
   OutputFieldTemplate,
 } from '../types/types';
-import {
-  buildInputFieldTemplate,
-  buildOutputFieldTemplates,
-} from './fieldTemplateBuilders';
+import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
+import { O } from 'ts-toolbelt';
+
+// recursively exclude all properties of type U from T
+type DeepExclude<T, U> = T extends U
+  ? never
+  : T extends object
+  ? {
+      [K in keyof T]: DeepExclude<T[K], U>;
+    }
+  : T;
+
+// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
+type DereferencedOpenAPIDocument = DeepExclude<
+  O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
+  OpenAPIV3.ReferenceObject
+>;
 
 const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
 
 const invocationDenylist = ['Graph', 'InvocationMeta'];
 
-export const parseSchema = (openAPI: OpenAPIV3.Document) => {
+const nodeFilter = (
+  schema: DereferencedOpenAPIDocument['components']['schemas'][string],
+  key: string
+) =>
+  key.includes('Invocation') &&
+  !key.includes('InvocationOutput') &&
+  !invocationDenylist.some((denylistItem) => key.includes(denylistItem));
+
+export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
   // filter out non-invocation schemas, plus some tricky invocations for now
-  const filteredSchemas = filter(
-    openAPI.components!.schemas,
-    (schema, key) =>
-      key.includes('Invocation') &&
-      !key.includes('InvocationOutput') &&
-      !invocationDenylist.some((denylistItem) => key.includes(denylistItem))
-  ) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
+  const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
 
   const invocations = filteredSchemas.reduce<
     Record<string, InvocationTemplate>
-  >((acc, schema) => {
-    // only want SchemaObjects
-    if (isInvocationSchemaObject(schema)) {
-      const type = schema.properties.type.default;
-
-      const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
-
-      const typeHints = schema.ui?.type_hints;
-
-      const inputs: Record<string, InputFieldTemplate> = {};
-
-      if (type === 'collect') {
-        const itemProperty = schema.properties[
-          'item'
-        ] as InvocationSchemaObject;
-        // Handle the special Collect node
-        inputs.item = {
-          type: 'item',
-          name: 'item',
-          description: itemProperty.description ?? '',
-          title: 'Collection Item',
-          inputKind: 'connection',
-          inputRequirement: 'always',
-          default: undefined,
-        };
-      } else if (type === 'iterate') {
-        const itemProperty = schema.properties[
-          'collection'
-        ] as InvocationSchemaObject;
-
-        inputs.collection = {
-          type: 'array',
-          name: 'collection',
-          title: itemProperty.title ?? '',
-          default: [],
-          description: itemProperty.description ?? '',
-          inputRequirement: 'always',
-          inputKind: 'connection',
-        };
-      } else {
-        // All other nodes
-        reduce(
-          schema.properties,
-          (inputsAccumulator, property, propertyName) => {
-            if (
-              // `type` and `id` are not valid inputs/outputs
-              !RESERVED_FIELD_NAMES.includes(propertyName) &&
-              isSchemaObject(property)
-            ) {
-              const field: InputFieldTemplate | undefined =
-                buildInputFieldTemplate(property, propertyName, typeHints);
-
-              if (field) {
-                inputsAccumulator[propertyName] = field;
-              }
-            }
-            return inputsAccumulator;
-          },
-          inputs
-        );
-      }
-
-      const rawOutput = (schema as InvocationSchemaObject).output;
-
-      let outputs: Record<string, OutputFieldTemplate>;
-
-      // some special handling is needed for collect, iterate and range nodes
-      if (type === 'iterate') {
-        // this is guaranteed to be a SchemaObject
-        const iterationOutput = openAPI.components!.schemas![
-          'IterateInvocationOutput'
-        ] as OpenAPIV3.SchemaObject;
-
-        outputs = {
-          item: {
-            name: 'item',
-            title: iterationOutput.title ?? '',
-            description: iterationOutput.description ?? '',
-            type: 'array',
-          },
-        };
-      } else {
-        outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
-      }
-
-      const invocation: InvocationTemplate = {
-        title,
-        type,
-        tags: schema.ui?.tags ?? [],
-        description: schema.description ?? '',
-        inputs,
-        outputs,
-      };
-
-      Object.assign(acc, { [type]: invocation });
-    }
-
+  >((acc, s) => {
+    // cast to InvocationSchemaObject, we know the shape
+    const schema = s as InvocationSchemaObject;
+
+    const type = schema.properties.type.default;
+
+    const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
+
+    const typeHints = schema.ui?.type_hints;
+
+    const inputs: Record<string, InputFieldTemplate> = {};
+
+    if (type === 'collect') {
+      // Special handling for the Collect node
+      const itemProperty = schema.properties['item'] as InvocationSchemaObject;
+      inputs.item = {
+        type: 'item',
+        name: 'item',
+        description: itemProperty.description ?? '',
+        title: 'Collection Item',
+        inputKind: 'connection',
+        inputRequirement: 'always',
+        default: undefined,
+      };
+    } else if (type === 'iterate') {
+      // Special handling for the Iterate node
+      const itemProperty = schema.properties[
+        'collection'
+      ] as InvocationSchemaObject;
+
+      inputs.collection = {
+        type: 'array',
+        name: 'collection',
+        title: itemProperty.title ?? '',
+        default: [],
+        description: itemProperty.description ?? '',
+        inputRequirement: 'always',
+        inputKind: 'connection',
+      };
+    } else {
+      // All other nodes
+      reduce(
+        schema.properties,
+        (inputsAccumulator, property, propertyName) => {
+          if (
+            // `type` and `id` are not valid inputs/outputs
+            !RESERVED_FIELD_NAMES.includes(propertyName) &&
+            isSchemaObject(property)
+          ) {
+            const field: InputFieldTemplate | undefined =
+              buildInputFieldTemplate(property, propertyName, typeHints);
+
+            if (field) {
+              inputsAccumulator[propertyName] = field;
+            }
+          }
+          return inputsAccumulator;
+        },
+        inputs
+      );
+    }
+
+    let outputs: Record<string, OutputFieldTemplate>;
+
+    if (type === 'iterate') {
+      // Special handling for the Iterate node output
+      const iterationOutput =
+        openAPI.components.schemas['IterateInvocationOutput'];
+
+      outputs = {
+        item: {
+          name: 'item',
+          title: iterationOutput.title ?? '',
+          description: iterationOutput.description ?? '',
+          type: 'array',
+        },
+      };
+    } else {
+      // All other node outputs
+      outputs = reduce(
+        schema.output.properties as OpenAPIV3.SchemaObject,
+        (outputsAccumulator, property, propertyName) => {
+          if (!['type', 'id'].includes(propertyName)) {
+            const fieldType = getFieldType(property, propertyName, typeHints);
+
+            outputsAccumulator[propertyName] = {
+              name: propertyName,
+              title: property.title ?? '',
+              description: property.description ?? '',
+              type: fieldType,
+            };
+          }
+
+          return outputsAccumulator;
+        },
+        {} as Record<string, OutputFieldTemplate>
+      );
+    }
+
+    const invocation: InvocationTemplate = {
+      title,
+      type,
+      tags: schema.ui?.tags ?? [],
+      description: schema.description ?? '',
+      inputs,
+      outputs,
+    };
+
+    Object.assign(acc, { [type]: invocation });
+
     return acc;
   }, {});
 
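`DeepExclude` recursively strips a type from every property, which is what lets `parseSchema` drop the `OpenAPIV3.ReferenceObject` half of each `SchemaObject | ReferenceObject` union once swagger-parser has dereferenced the document. A compact illustration:

```ts
import { OpenAPIV3 } from 'openapi-types';

type DeepExclude<T, U> = T extends U
  ? never
  : T extends object
  ? { [K in keyof T]: DeepExclude<T[K], U> }
  : T;

// Unions distribute over the conditional: the ReferenceObject branch maps
// to `never` and drops out, leaving the (recursively mapped) SchemaObject.
type Deref = DeepExclude<
  OpenAPIV3.SchemaObject | OpenAPIV3.ReferenceObject,
  OpenAPIV3.ReferenceObject
>;
```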
@@ -1,4 +1,5 @@
 import { createSelector } from '@reduxjs/toolkit';
+import { stateSelector } from 'app/store/store';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAINumberInput from 'common/components/IAINumberInput';
 import IAISlider from 'common/components/IAISlider';
@@ -10,27 +11,27 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
 import { memo, useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 
-const selector = createSelector(
-  [generationSelector, configSelector, uiSelector, hotkeysSelector],
-  (generation, config, ui, hotkeys) => {
-    const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
-      config.sd.iterations;
-    const { iterations } = generation;
-    const { shouldUseSliders } = ui;
+const selector = createSelector([stateSelector], (state) => {
+  const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
+    state.config.sd.iterations;
+  const { iterations } = state.generation;
+  const { shouldUseSliders } = state.ui;
+  const isDisabled =
+    state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
 
-  const step = hotkeys.shift ? fineStep : coarseStep;
+  const step = state.hotkeys.shift ? fineStep : coarseStep;
 
   return {
     iterations,
     initial,
     min,
     sliderMax,
     inputMax,
     step,
     shouldUseSliders,
-  };
-  }
-);
+    isDisabled,
+  };
+});
 
 const ParamIterations = () => {
   const {
@@ -41,6 +42,7 @@ const ParamIterations = () => {
     inputMax,
     step,
     shouldUseSliders,
+    isDisabled,
   } = useAppSelector(selector);
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
@@ -58,6 +60,7 @@ const ParamIterations = () => {
 
   return shouldUseSliders ? (
     <IAISlider
+      isDisabled={isDisabled}
       label={t('parameters.images')}
       step={step}
       min={min}
@@ -72,6 +75,7 @@ const ParamIterations = () => {
     />
   ) : (
     <IAINumberInput
+      isDisabled={isDisabled}
      label={t('parameters.images')}
      step={step}
      min={min}
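Presumably the iterations control is disabled here because, with combinatorial dynamic prompts enabled, the number of generated images is driven by the prompt expansion rather than by this setting. The guard reduces to:

```ts
// Disabled whenever combinatorial dynamic prompting takes over batching.
const isDisabled = (dp: { isEnabled: boolean; combinatorial: boolean }): boolean =>
  dp.isEnabled && dp.combinatorial;
```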
@@ -24,7 +24,7 @@ const ModelSelect = () => {
   );
 
   const { data: pipelineModels } = useListModelsQuery({
-    model_type: 'pipeline',
+    model_type: 'main',
   });
 
   const data = useMemo(() => {
@@ -60,6 +60,14 @@ export const initialConfigState: AppConfig = {
       fineStep: 0.01,
       coarseStep: 0.05,
     },
+    dynamicPrompts: {
+      maxPrompts: {
+        initial: 100,
+        min: 1,
+        sliderMax: 1000,
+        inputMax: 10000,
+      },
+    },
   },
 };
 
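The new `dynamicPrompts.maxPrompts` block reuses the numeric-parameter shape the other sliders consume. The implied structure, with hypothetical type names (only the field layout comes from the hunk above):

```ts
// Hypothetical names; the fields match the config block above.
type NumericParamConfig = {
  initial: number;
  min: number;
  sliderMax: number;
  inputMax: number;
};

type DynamicPromptsConfig = {
  maxPrompts: NumericParamConfig;
};
```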
@@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
 
 import { InvokeLogLevel } from 'app/logging/useLogger';
 import { userInvoked } from 'app/store/actions';
-import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
 import { TFuncKey, t } from 'i18next';
 import { LogLevelName } from 'roarr';
 import {
@@ -26,6 +25,7 @@ import {
 } from 'services/api/thunks/session';
 import { makeToast } from '../../../app/components/Toaster';
 import { LANGUAGES } from '../components/LanguagePicker';
+import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
 
 export type CancelStrategy = 'immediate' | 'scheduled';
 
@@ -382,7 +382,7 @@ export const systemSlice = createSlice({
     /**
      * OpenAPI schema was parsed
      */
-    builder.addCase(parsedOpenAPISchema, (state) => {
+    builder.addCase(nodeTemplatesBuilt, (state) => {
       state.wasSchemaParsed = true;
     });
 
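Replacing `parsedOpenAPISchema` with `nodeTemplatesBuilt` keeps the same cross-slice pattern: the system slice reacts to an action owned by the nodes slice through `extraReducers`. The shape of that pattern, as a self-contained sketch:

```ts
import { createAction, createSlice } from '@reduxjs/toolkit';

// Stand-in for the action exported by the nodes slice.
const nodeTemplatesBuilt = createAction('nodes/nodeTemplatesBuilt');

const systemSliceSketch = createSlice({
  name: 'system',
  initialState: { wasSchemaParsed: false },
  reducers: {},
  extraReducers: (builder) => {
    // React to another slice's action without owning it.
    builder.addCase(nodeTemplatesBuilt, (state) => {
      state.wasSchemaParsed = true;
    });
  },
});
```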
@@ -17,7 +17,6 @@ import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
 import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
 import { useHotkeys } from 'react-hotkeys-hook';
 import { MdDeviceHub, MdGridOn } from 'react-icons/md';
-import { GoTextSize } from 'react-icons/go';
 import {
   activeTabIndexSelector,
   activeTabNameSelector,
@@ -33,7 +32,7 @@ import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent
 import TextToImageTab from './tabs/TextToImage/TextToImageTab';
 import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
 import NodesTab from './tabs/Nodes/NodesTab';
-import { FaImage } from 'react-icons/fa';
+import { FaFont, FaImage } from 'react-icons/fa';
 import ResizeHandle from './tabs/ResizeHandle';
 import ImageTab from './tabs/ImageToImage/ImageToImageTab';
 import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
@@ -47,7 +46,7 @@ export interface InvokeTabInfo {
 const tabs: InvokeTabInfo[] = [
   {
     id: 'txt2img',
-    icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
+    icon: <Icon as={FaFont} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
     content: <TextToImageTab />,
   },
   {
@@ -8,6 +8,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
 import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
 import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
+import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
 
 const ImageToImageTabParameters = () => {
   return (
@@ -16,6 +17,7 @@ const ImageToImageTabParameters = () => {
       <ParamNegativeConditioning />
       <ProcessButtons />
       <ImageToImageTabCoreParameters />
+      <ParamDynamicPromptsCollapse />
       <ParamControlNetCollapse />
       <ParamVariationCollapse />
       <ParamNoiseCollapse />
@@ -9,6 +9,7 @@ import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/
 import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
 import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
+import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
 
 const TextToImageTabParameters = () => {
   return (
@@ -17,6 +18,7 @@ const TextToImageTabParameters = () => {
       <ParamNegativeConditioning />
       <ProcessButtons />
       <TextToImageTabCoreParameters />
+      <ParamDynamicPromptsCollapse />
       <ParamControlNetCollapse />
       <ParamVariationCollapse />
       <ParamNoiseCollapse />
@@ -8,6 +8,7 @@ import { memo } from 'react';
 import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
 import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
 import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
+import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
 
 const UnifiedCanvasParameters = () => {
   return (
@@ -16,6 +17,7 @@ const UnifiedCanvasParameters = () => {
       <ParamNegativeConditioning />
       <ProcessButtons />
       <UnifiedCanvasCoreParameters />
+      <ParamDynamicPromptsCollapse />
       <ParamControlNetCollapse />
       <ParamVariationCollapse />
       <ParamSymmetryCollapse />
@@ -15,6 +15,7 @@ export const imagesApi = api.injectEndpoints({
         }
         return tags;
       },
+      keepUnusedDataFor: 86400, // 24 hours
     }),
   }),
 });
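`keepUnusedDataFor` is RTK Query's cache retention window in seconds: once the last subscriber to a cached image query unsubscribes, the entry is kept for that long before being purged, so here image metadata survives for a full day.

```ts
// 86400 seconds = 24 hours of cache retention after the last unsubscribe.
const keepUnusedDataFor = 24 * 60 * 60; // 86400
```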
@@ -0,0 +1,36 @@
+/* istanbul ignore file */
+/* tslint:disable */
+/* eslint-disable */
+
+import type { ImageField } from './ImageField';
+
+/**
+ * Applies HED edge detection to image
+ */
+export type HedImageProcessorInvocation = {
+  /**
+   * The id of this node. Must be unique among all nodes.
+   */
+  id: string;
+  /**
+   * Whether or not this node is an intermediate node.
+   */
+  is_intermediate?: boolean;
+  type?: 'hed_image_processor';
+  /**
+   * The image to process
+   */
+  image?: ImageField;
+  /**
+   * The pixel resolution for detection
+   */
+  detect_resolution?: number;
+  /**
+   * The pixel resolution for the output image
+   */
+  image_resolution?: number;
+  /**
+   * Whether to use scribble mode
+   */
+  scribble?: boolean;
+};
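A value conforming to the newly generated type looks like the following. The field values are illustrative, and the import path assumes the generated model is re-exported from `services/api/types` like the other invocation types in this diff:

```ts
import type { HedImageProcessorInvocation } from 'services/api/types';

const hedNode: HedImageProcessorInvocation = {
  id: 'hed_processor_1', // illustrative id; must be unique in the graph
  type: 'hed_image_processor',
  detect_resolution: 512,
  image_resolution: 512,
  scribble: false,
};
```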
@@ -648,6 +648,13 @@ export type components = {
        * @default 1
        */
       end_step_percent: number;
+      /**
+       * Control Mode
+       * @description The contorl mode to use
+       * @default balanced
+       * @enum {string}
+       */
+      control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
     };
     /**
      * ControlNetInvocation
@@ -701,6 +708,13 @@ export type components = {
        * @default 1
        */
       end_step_percent?: number;
+      /**
+       * Control Mode
+       * @description The control mode used
+       * @default balanced
+       * @enum {string}
+       */
+      control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
     };
     /** ControlNetModelConfig */
     ControlNetModelConfig: {
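`control_mode` tunes the balance between prompt guidance and ControlNet guidance, defaulting to `balanced`. Setting it is just another optional field on the invocation (other fields omitted for brevity):

```ts
// One of the enum values from the schema above.
const controlNetPatch = {
  end_step_percent: 1,
  control_mode: 'more_control' as const,
};
```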
@ -1016,7 +1030,7 @@ export type components = {
|
|||||||
* @description The nodes in this graph
|
* @description The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: {
|
nodes?: {
|
||||||
-      [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
+      [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
     };
     /**
      * Edges
@@ -1059,7 +1073,7 @@ export type components = {
      * @description The results of node executions
      */
     results: {
-      [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
+      [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
     };
     /**
      * Errors
@@ -2903,7 +2917,7 @@ export type components = {
    /** ModelsList */
    ModelsList: {
      /** Models */
-      models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
+      models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
    };
    /**
     * MultiplyInvocation
@@ -2979,6 +2993,18 @@ export type components = {
      * @default 512
      */
     height?: number;
+    /**
+     * Perlin
+     * @description The amount of perlin noise to add to the noise
+     * @default 0
+     */
+    perlin?: number;
+    /**
+     * Use Cpu
+     * @description Use CPU for noise generation (for reproducible results across platforms)
+     * @default true
+     */
+    use_cpu?: boolean;
   };
   /**
    * NoiseOutput
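The two fields added above surface directly on the generated `NoiseInvocation` type. A minimal sketch of a node object using them; everything except `perlin` and `use_cpu` (the `id`, `type`, `width`, `height` fields and the import path) is assumed from the surrounding generated schema rather than taken from this hunk:

import { components } from './schema';

type NoiseInvocation = components['schemas']['NoiseInvocation'];

const noiseNode: NoiseInvocation = {
  id: 'noise_1', // hypothetical node id
  type: 'noise', // assumed literal for this node type
  width: 512,
  height: 512,
  perlin: 0, // new field: amount of perlin noise to add (@default 0)
  use_cpu: true, // new field: CPU noise for reproducible results across platforms (@default true)
};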
@@ -4163,18 +4189,18 @@ export type components = {
      */
     image?: components["schemas"]["ImageField"];
   };
-  /**
-   * StableDiffusion1ModelFormat
-   * @description An enumeration.
-   * @enum {string}
-   */
-  StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
   /**
    * StableDiffusion2ModelFormat
    * @description An enumeration.
    * @enum {string}
    */
   StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
+  /**
+   * StableDiffusion1ModelFormat
+   * @description An enumeration.
+   * @enum {string}
+   */
+  StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
 };
 responses: never;
 parameters: never;
@@ -4285,7 +4311,7 @@ export type operations = {
     };
     requestBody: {
       content: {
-        "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
+        "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
       };
     };
     responses: {
@@ -4322,7 +4348,7 @@ export type operations = {
     };
     requestBody: {
       content: {
-        "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
+        "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
       };
     };
     responses: {
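The `application/json` union above is the body a client must satisfy when appending a node to a session graph. A sketch under stated assumptions: the route string is illustrative only (the generated `operations` type pins down the JSON body shape, not the URL), and `AddInvocation` stands in for any member of the union:

import { components } from './schema';

type AddInvocation = components['schemas']['AddInvocation'];

// Hypothetical helper; the path is an assumption for illustration.
async function appendNode(sessionId: string, node: AddInvocation) {
  const res = await fetch(`/api/v1/sessions/${sessionId}/nodes`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(node),
  });
  // The response shape is defined by the operation's `responses` type.
  return res.json();
}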
@@ -1,20 +1,45 @@
+import SwaggerParser from '@apidevtools/swagger-parser';
 import { createAsyncThunk } from '@reduxjs/toolkit';
 import { log } from 'app/logging/useLogger';
-import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
 import { OpenAPIV3 } from 'openapi-types';
 
 const schemaLog = log.child({ namespace: 'schema' });
 
+function getCircularReplacer() {
+  const ancestors: Record<string, any>[] = [];
+  return function (key: string, value: any) {
+    if (typeof value !== 'object' || value === null) {
+      return value;
+    }
+    // `this` is the object that value is contained in,
+    // i.e., its direct parent.
+    // @ts-ignore
+    while (ancestors.length > 0 && ancestors.at(-1) !== this) {
+      ancestors.pop();
+    }
+    if (ancestors.includes(value)) {
+      return '[Circular]';
+    }
+    ancestors.push(value);
+    return value;
+  };
+}
+
 export const receivedOpenAPISchema = createAsyncThunk(
   'nodes/receivedOpenAPISchema',
-  async (_, { dispatch }): Promise<OpenAPIV3.Document> => {
-    const response = await fetch(`openapi.json`);
-    const openAPISchema = await response.json();
-
-    schemaLog.info({ openAPISchema }, 'Received OpenAPI schema');
-
-    dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document));
-
-    return openAPISchema;
+  async (_, { dispatch, rejectWithValue }) => {
+    try {
+      const dereferencedSchema = (await SwaggerParser.dereference(
+        'openapi.json'
+      )) as OpenAPIV3.Document;
+
+      const schemaJSON = JSON.parse(
+        JSON.stringify(dereferencedSchema, getCircularReplacer())
+      );
+
+      return schemaJSON;
+    } catch (error) {
+      return rejectWithValue({ error });
+    }
   }
 );
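To see why the thunk stringifies through `getCircularReplacer` before re-parsing: a dereferenced OpenAPI document can contain self-referential schema objects, and a bare `JSON.stringify` throws on cycles. A small self-contained demonstration using the replacer defined above:

const root: Record<string, any> = { name: 'root' };
root.self = root; // introduce a cycle

// JSON.stringify(root) alone would throw:
//   TypeError: Converting circular structure to JSON
const safe = JSON.stringify(root, getCircularReplacer());
// safe === '{"name":"root","self":"[Circular]"}'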
invokeai/frontend/web/src/services/api/types.d.ts (vendored, 121 lines changed)
@@ -1,81 +1,92 @@
+import { O } from 'ts-toolbelt';
+
 import { components } from './schema';
 
+type schemas = components['schemas'];
+
 /**
- * Types from the API, re-exported from the types generated by `openapi-typescript`.
+ * Extracts the schema type from the schema.
  */
+type S<T extends keyof components['schemas']> = components['schemas'][T];
+
+/**
+ * Extracts the node type from the schema.
+ * Also flags the `type` property as required.
+ */
+type N<T extends keyof components['schemas']> = O.Required<
+  components['schemas'][T],
+  'type'
+>;
+
 // Images
-export type ImageDTO = components['schemas']['ImageDTO'];
-export type BoardDTO = components['schemas']['BoardDTO'];
-export type BoardChanges = components['schemas']['BoardChanges'];
-export type ImageChanges = components['schemas']['ImageRecordChanges'];
-export type ImageCategory = components['schemas']['ImageCategory'];
-export type ResourceOrigin = components['schemas']['ResourceOrigin'];
-export type ImageField = components['schemas']['ImageField'];
+export type ImageDTO = S<'ImageDTO'>;
+export type BoardDTO = S<'BoardDTO'>;
+export type BoardChanges = S<'BoardChanges'>;
+export type ImageChanges = S<'ImageRecordChanges'>;
+export type ImageCategory = S<'ImageCategory'>;
+export type ResourceOrigin = S<'ResourceOrigin'>;
+export type ImageField = S<'ImageField'>;
 export type OffsetPaginatedResults_BoardDTO_ =
-  components['schemas']['OffsetPaginatedResults_BoardDTO_'];
+  S<'OffsetPaginatedResults_BoardDTO_'>;
 export type OffsetPaginatedResults_ImageDTO_ =
-  components['schemas']['OffsetPaginatedResults_ImageDTO_'];
+  S<'OffsetPaginatedResults_ImageDTO_'>;
 
 // Models
-export type ModelType = components['schemas']['ModelType'];
-export type BaseModelType = components['schemas']['BaseModelType'];
-export type PipelineModelField = components['schemas']['PipelineModelField'];
-export type ModelsList = components['schemas']['ModelsList'];
+export type ModelType = S<'ModelType'>;
+export type BaseModelType = S<'BaseModelType'>;
+export type PipelineModelField = S<'PipelineModelField'>;
+export type ModelsList = S<'ModelsList'>;
 
 // Graphs
-export type Graph = components['schemas']['Graph'];
-export type Edge = components['schemas']['Edge'];
-export type GraphExecutionState = components['schemas']['GraphExecutionState'];
+export type Graph = S<'Graph'>;
+export type Edge = S<'Edge'>;
+export type GraphExecutionState = S<'GraphExecutionState'>;
 
 // General nodes
-export type CollectInvocation = components['schemas']['CollectInvocation'];
-export type IterateInvocation = components['schemas']['IterateInvocation'];
-export type RangeInvocation = components['schemas']['RangeInvocation'];
-export type RandomRangeInvocation =
-  components['schemas']['RandomRangeInvocation'];
-export type RangeOfSizeInvocation =
-  components['schemas']['RangeOfSizeInvocation'];
-export type InpaintInvocation = components['schemas']['InpaintInvocation'];
-export type ImageResizeInvocation =
-  components['schemas']['ImageResizeInvocation'];
-export type RandomIntInvocation = components['schemas']['RandomIntInvocation'];
-export type CompelInvocation = components['schemas']['CompelInvocation'];
+export type CollectInvocation = N<'CollectInvocation'>;
+export type IterateInvocation = N<'IterateInvocation'>;
+export type RangeInvocation = N<'RangeInvocation'>;
+export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
+export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
+export type InpaintInvocation = N<'InpaintInvocation'>;
+export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
+export type RandomIntInvocation = N<'RandomIntInvocation'>;
+export type CompelInvocation = N<'CompelInvocation'>;
+export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
+export type NoiseInvocation = N<'NoiseInvocation'>;
+export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
+export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
+export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
+export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
+export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
 
 // ControlNet Nodes
-export type CannyImageProcessorInvocation =
-  components['schemas']['CannyImageProcessorInvocation'];
+export type ControlNetInvocation = N<'ControlNetInvocation'>;
+export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
 export type ContentShuffleImageProcessorInvocation =
-  components['schemas']['ContentShuffleImageProcessorInvocation'];
-export type HedImageProcessorInvocation =
-  components['schemas']['HedImageProcessorInvocation'];
+  N<'ContentShuffleImageProcessorInvocation'>;
+export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
 export type LineartAnimeImageProcessorInvocation =
-  components['schemas']['LineartAnimeImageProcessorInvocation'];
+  N<'LineartAnimeImageProcessorInvocation'>;
 export type LineartImageProcessorInvocation =
-  components['schemas']['LineartImageProcessorInvocation'];
+  N<'LineartImageProcessorInvocation'>;
 export type MediapipeFaceProcessorInvocation =
-  components['schemas']['MediapipeFaceProcessorInvocation'];
+  N<'MediapipeFaceProcessorInvocation'>;
 export type MidasDepthImageProcessorInvocation =
-  components['schemas']['MidasDepthImageProcessorInvocation'];
-export type MlsdImageProcessorInvocation =
-  components['schemas']['MlsdImageProcessorInvocation'];
+  N<'MidasDepthImageProcessorInvocation'>;
+export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
 export type NormalbaeImageProcessorInvocation =
-  components['schemas']['NormalbaeImageProcessorInvocation'];
+  N<'NormalbaeImageProcessorInvocation'>;
 export type OpenposeImageProcessorInvocation =
-  components['schemas']['OpenposeImageProcessorInvocation'];
-export type PidiImageProcessorInvocation =
-  components['schemas']['PidiImageProcessorInvocation'];
+  N<'OpenposeImageProcessorInvocation'>;
+export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
 export type ZoeDepthImageProcessorInvocation =
-  components['schemas']['ZoeDepthImageProcessorInvocation'];
+  N<'ZoeDepthImageProcessorInvocation'>;
 
 // Node Outputs
-export type ImageOutput = components['schemas']['ImageOutput'];
-export type MaskOutput = components['schemas']['MaskOutput'];
-export type PromptOutput = components['schemas']['PromptOutput'];
-export type IterateInvocationOutput =
-  components['schemas']['IterateInvocationOutput'];
-export type CollectInvocationOutput =
-  components['schemas']['CollectInvocationOutput'];
-export type LatentsOutput = components['schemas']['LatentsOutput'];
-export type GraphInvocationOutput =
-  components['schemas']['GraphInvocationOutput'];
+export type ImageOutput = S<'ImageOutput'>;
+export type MaskOutput = S<'MaskOutput'>;
+export type PromptOutput = S<'PromptOutput'>;
+export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
+export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
+export type LatentsOutput = S<'LatentsOutput'>;
+export type GraphInvocationOutput = S<'GraphInvocationOutput'>;
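The practical difference between `S<>` and `N<>` above is only that `N<>` runs the schema type through ts-toolbelt's `O.Required`, making `type` mandatory. A minimal sketch with a hypothetical stand-in for a generated type:

import { O } from 'ts-toolbelt';

// Hypothetical generated shape: `type` is optional in the schema output.
type Generated = { type?: 'range'; start?: number; stop?: number };

// O.Required<T, K> flags the listed keys as required.
type RangeNode = O.Required<Generated, 'type'>;

const ok: RangeNode = { type: 'range', start: 0 };
// const bad: RangeNode = { start: 0 }; // compile error: 'type' is missing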
@@ -9,6 +9,7 @@
     "vite.config.ts",
     "./config/vite.app.config.ts",
     "./config/vite.package.config.ts",
-    "./config/vite.common.config.ts"
+    "./config/vite.common.config.ts",
+    "./config/common.ts"
   ]
 }
File diff suppressed because it is too large
@@ -39,7 +39,7 @@ dependencies = [
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
   "compel>=1.2.1",
-  "controlnet-aux>=0.0.4",
+  "controlnet-aux>=0.0.6",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "datasets",
   "diffusers[torch]~=0.17.1",
tests/conftest.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import pytest
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
+from invokeai.app.services.processor import DefaultInvocationProcessor
+
+# Ignore these files as they need to be rewritten following the model manager refactor
+collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
+
+@pytest.fixture(scope="session", autouse=True)
+def mock_services():
+    # NOTE: none of these are actually called by the test invocations
+    return InvocationServices(
+        model_manager = None, # type: ignore
+        events = None, # type: ignore
+        logger = None, # type: ignore
+        images = None, # type: ignore
+        latents = None, # type: ignore
+        board_images=None, # type: ignore
+        boards=None, # type: ignore
+        queue = MemoryInvocationQueue(),
+        graph_library=SqliteItemStorage[LibraryGraph](
+            filename=sqlite_memory, table_name="graphs"
+        ),
+        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
+        processor = DefaultInvocationProcessor(),
+        restoration = None, # type: ignore
+        configuration = None, # type: ignore
+    )
@@ -1,14 +1,18 @@
-from .test_invoker import create_edge
-from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
+import pytest
+
+from invokeai.app.invocations.baseinvocation import (BaseInvocation,
+                                                     BaseInvocationOutput,
+                                                     InvocationContext)
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
-from invokeai.app.services.processor import DefaultInvocationProcessor
-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
-from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.graph import (CollectInvocation, Graph,
+                                         GraphExecutionState,
+                                         IterateInvocation)
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
-import pytest
+
+from .test_invoker import create_edge
+from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
+                         PromptTestInvocation)
 
 
 @pytest.fixture
@@ -19,30 +23,11 @@ def simple_graph():
     g.add_edge(create_edge("1", "prompt", "2", "prompt"))
     return g
 
-@pytest.fixture
-def mock_services():
-    # NOTE: none of these are actually called by the test invocations
-    return InvocationServices(
-        model_manager = None, # type: ignore
-        events = None, # type: ignore
-        logger = None, # type: ignore
-        images = None, # type: ignore
-        latents = None, # type: ignore
-        queue = MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](
-            filename=sqlite_memory, table_name="graphs"
-        ),
-        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
-        processor = DefaultInvocationProcessor(),
-        restoration = None, # type: ignore
-        configuration = None, # type: ignore
-    )
-
 def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
     n = g.next()
     if n is None:
         return (None, None)
 
     print(f'invoking {n.id}: {type(n)}')
     o = n.invoke(InvocationContext(services, "1"))
     g.complete(n.id, o)
@@ -51,7 +36,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
 
 def test_graph_state_executes_in_order(simple_graph, mock_services):
     g = GraphExecutionState(graph = simple_graph)
 
     n1 = invoke_next(g, mock_services)
     n2 = invoke_next(g, mock_services)
     n3 = g.next()
@@ -88,11 +73,11 @@ def test_graph_state_expands_iterator(mock_services):
     graph.add_edge(create_edge("0", "collection", "1", "collection"))
     graph.add_edge(create_edge("1", "item", "2", "a"))
     graph.add_edge(create_edge("2", "a", "3", "a"))
 
     g = GraphExecutionState(graph = graph)
     while not g.is_complete():
         invoke_next(g, mock_services)
 
     prepared_add_nodes = g.source_prepared_mapping['3']
     results = set([g.results[n].a for n in prepared_add_nodes])
     expected = set([1, 11, 21])
@@ -109,7 +94,7 @@ def test_graph_state_collects(mock_services):
     graph.add_edge(create_edge("1", "collection", "2", "collection"))
     graph.add_edge(create_edge("2", "item", "3", "prompt"))
     graph.add_edge(create_edge("3", "prompt", "4", "item"))
 
     g = GraphExecutionState(graph = graph)
     n1 = invoke_next(g, mock_services)
     n2 = invoke_next(g, mock_services)
@@ -1,13 +1,12 @@
-from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
-from invokeai.app.services.processor import DefaultInvocationProcessor
-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
-from invokeai.app.services.invocation_queue import MemoryInvocationQueue
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
-from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
 import pytest
+
+from invokeai.app.services.graph import Graph, GraphExecutionState
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.invoker import Invoker
+
+from .test_nodes import (ErrorInvocation, ImageTestInvocation,
+                         PromptTestInvocation, create_edge, wait_until)
 
 
 @pytest.fixture
 def simple_graph():
@@ -17,25 +16,6 @@ def simple_graph():
     g.add_edge(create_edge("1", "prompt", "2", "prompt"))
     return g
 
-@pytest.fixture
-def mock_services() -> InvocationServices:
-    # NOTE: none of these are actually called by the test invocations
-    return InvocationServices(
-        model_manager = None, # type: ignore
-        events = TestEventService(),
-        logger = None, # type: ignore
-        images = None, # type: ignore
-        latents = None, # type: ignore
-        queue = MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](
-            filename=sqlite_memory, table_name="graphs"
-        ),
-        graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
-        processor = DefaultInvocationProcessor(),
-        restoration = None, # type: ignore
-        configuration = None, # type: ignore
-    )
-
 @pytest.fixture()
 def mock_invoker(mock_services: InvocationServices) -> Invoker:
     return Invoker(
@@ -57,6 +37,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
     assert isinstance(g, GraphExecutionState)
     assert g.graph == simple_graph
 
+@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_can_invoke(mock_invoker: Invoker, simple_graph):
     g = mock_invoker.create_execution_state(graph = simple_graph)
     invocation_id = mock_invoker.invoke(g)
@@ -72,6 +53,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert len(g.executed) > 0
 
+@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
     g = mock_invoker.create_execution_state(graph = simple_graph)
     invocation_id = mock_invoker.invoke(g, invoke_all = True)
@@ -87,6 +69,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
     g = mock_invoker.services.graph_execution_manager.get(g.id)
     assert g.is_complete()
 
+@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
 def test_handles_errors(mock_invoker: Invoker):
     g = mock_invoker.create_execution_state()
     g.graph.add_node(ErrorInvocation(id = "1"))