Merge branch 'main' into feat/lora_model_patch

This commit is contained in:
StAlKeR7779 2023-06-28 22:43:58 +03:00 committed by GitHub
commit ac46b129bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
89 changed files with 3203 additions and 2812 deletions

View File

@ -1,10 +1,16 @@
name: Test invoke.py pip name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on: on:
pull_request: pull_request:
paths: paths:
- '**' - '**'
- '!pyproject.toml' - '!pyproject.toml'
- '!invokeai/**' - '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**' - 'invokeai/frontend/web/**'
merge_group: merge_group:
workflow_dispatch: workflow_dispatch:
@ -19,48 +25,26 @@ jobs:
strategy: strategy:
matrix: matrix:
python-version: python-version:
# - '3.9'
- '3.10' - '3.10'
pytorch: pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7 - linux-cuda-11_7
- linux-rocm-5_2 - linux-rocm-5_2
- linux-cpu - linux-cpu
- macos-default - macos-default
- windows-cpu - windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include: include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7 - pytorch: linux-cuda-11_7
os: ubuntu-22.04 os: ubuntu-22.04
github-env: $GITHUB_ENV
- pytorch: linux-rocm-5_2 - pytorch: linux-rocm-5_2
os: ubuntu-22.04 os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
github-env: $GITHUB_ENV
- pytorch: linux-cpu - pytorch: linux-cpu
os: ubuntu-22.04 os: ubuntu-22.04
extra-index-url: 'https://download.pytorch.org/whl/cpu'
github-env: $GITHUB_ENV
- pytorch: macos-default - pytorch: macos-default
os: macOS-12 os: macOS-12
github-env: $GITHUB_ENV
- pytorch: windows-cpu - pytorch: windows-cpu
os: windows-2022 os: windows-2022
github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- run: 'echo "No build required"' - name: skip
run: echo "no build required"

View File

@ -11,6 +11,7 @@ on:
paths: paths:
- 'pyproject.toml' - 'pyproject.toml'
- 'invokeai/**' - 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**' - '!invokeai/frontend/web/**'
types: types:
- 'ready_for_review' - 'ready_for_review'
@ -32,19 +33,12 @@ jobs:
# - '3.9' # - '3.9'
- '3.10' - '3.10'
pytorch: pytorch:
# - linux-cuda-11_6
- linux-cuda-11_7 - linux-cuda-11_7
- linux-rocm-5_2 - linux-rocm-5_2
- linux-cpu - linux-cpu
- macos-default - macos-default
- windows-cpu - windows-cpu
# - windows-cuda-11_6
# - windows-cuda-11_7
include: include:
# - pytorch: linux-cuda-11_6
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $GITHUB_ENV
- pytorch: linux-cuda-11_7 - pytorch: linux-cuda-11_7
os: ubuntu-22.04 os: ubuntu-22.04
github-env: $GITHUB_ENV github-env: $GITHUB_ENV
@ -62,14 +56,6 @@ jobs:
- pytorch: windows-cpu - pytorch: windows-cpu
os: windows-2022 os: windows-2022
github-env: $env:GITHUB_ENV github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_6
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu116'
# github-env: $env:GITHUB_ENV
# - pytorch: windows-cuda-11_7
# os: windows-2022
# extra-index-url: 'https://download.pytorch.org/whl/cu117'
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }} name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
env: env:
@ -100,40 +86,38 @@ jobs:
id: run-pytest id: run-pytest
run: pytest run: pytest
- name: run invokeai-configure # - name: run invokeai-configure
id: run-preload-models # env:
env: # HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }} # run: >
run: > # invokeai-configure
invokeai-configure # --yes
--yes # --default_only
--default_only # --full-precision
--full-precision # # can't use fp16 weights without a GPU
# can't use fp16 weights without a GPU
- name: run invokeai # - name: run invokeai
id: run-invokeai # id: run-invokeai
env: # env:
# Set offline mode to make sure configure preloaded successfully. # # Set offline mode to make sure configure preloaded successfully.
HF_HUB_OFFLINE: 1 # HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1 # HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1 # TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results # INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: > # run: >
invokeai # invokeai
--no-patchmatch # --no-patchmatch
--no-nsfw_checker # --no-nsfw_checker
--precision=float32 # --precision=float32
--always_use_cpu # --always_use_cpu
--use_memory_db # --use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }} # --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }} # --from_file ${{ env.TEST_PROMPTS }}
- name: Archive results # - name: Archive results
id: archive-results # env:
env: # INVOKEAI_OUTDIR: ${{ github.workspace }}/results
INVOKEAI_OUTDIR: ${{ github.workspace }}/results # uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v3 # with:
with: # name: results
name: results # path: ${{ env.INVOKEAI_OUTDIR }}
path: ${{ env.INVOKEAI_OUTDIR }}

View File

@ -87,16 +87,16 @@ Prior to installing PyPatchMatch, you need to take the following steps:
sudo pacman -S --needed base-devel sudo pacman -S --needed base-devel
``` ```
2. Install `opencv`: 2. Install `opencv` and `blas`:
```sh ```sh
sudo pacman -S opencv sudo pacman -S opencv blas
``` ```
or for CUDA support or for CUDA support
```sh ```sh
sudo pacman -S opencv-cuda sudo pacman -S opencv-cuda blas
``` ```
3. Fix the naming of the `opencv` package configuration file: 3. Fix the naming of the `opencv` package configuration file:

View File

@ -38,6 +38,7 @@ echo https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist
echo. echo.
echo See %INSTRUCTIONS% for more details. echo See %INSTRUCTIONS% for more details.
echo. echo.
echo "For the best user experience we suggest enlarging or maximizing this window now."
pause pause
@rem ---------------------------- check Python version --------------- @rem ---------------------------- check Python version ---------------

View File

@ -26,6 +26,7 @@ done
if [ -z "$PYTHON" ]; then if [ -z "$PYTHON" ]; then
echo "A suitable Python interpreter could not be found" echo "A suitable Python interpreter could not be found"
echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help." echo "Please install Python $MINIMUM_PYTHON_VERSION or higher (maximum $MAXIMUM_PYTHON_VERSION) before running this script. See instructions at $INSTRUCTIONS for help."
echo "For the best user experience we suggest enlarging or maximizing this window now."
read -p "Press any key to exit" read -p "Press any key to exit"
exit -1 exit -1
fi fi

View File

@ -293,6 +293,8 @@ def introduction() -> None:
"3. Create initial configuration files.", "3. Create initial configuration files.",
"", "",
"[i]At any point you may interrupt this program and resume later.", "[i]At any point you may interrupt this program and resume later.",
"",
"[b]For the best user experience, please enlarge or maximize this window",
), ),
) )
) )

View File

@ -1,10 +1,11 @@
# InvokeAI nodes for ControlNet image preprocessors # Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023 # initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import float from builtins import float, bool
import cv2
import numpy as np import numpy as np
from typing import Literal, Optional, Union, List from typing import Literal, Optional, Union, List, Dict
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
@ -29,8 +30,13 @@ from controlnet_aux import (
ContentShuffleDetector, ContentShuffleDetector,
ZoeDetector, ZoeDetector,
MediapipeFaceDetector, MediapipeFaceDetector,
SamDetector,
LeresDetector,
) )
from controlnet_aux.util import HWC3, ade_palette
from .image import ImageOutput, PILInvocationConfig from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [ CONTROLNET_DEFAULT_MODELS = [
@ -94,6 +100,10 @@ CONTROLNET_DEFAULT_MODELS = [
] ]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
# crop and fill options not ready yet
# CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
class ControlField(BaseModel): class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image") image: ImageField = Field(default=None, description="The control image")
@ -104,6 +114,9 @@ class ControlField(BaseModel):
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
# resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight") @validator("control_weight")
def abs_le_one(cls, v): def abs_le_one(cls, v):
"""validate that all abs(values) are <=1""" """validate that all abs(values) are <=1"""
@ -144,11 +157,11 @@ class ControlNetInvocation(BaseInvocation):
control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny", control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
description="control model used") description="control model used")
control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
# TODO: add support in backend core for begin_step_percent, end_step_percent, guess_mode
begin_step_percent: float = Field(default=0, ge=0, le=1, begin_step_percent: float = Field(default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)") description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(default=1, ge=0, le=1, end_step_percent: float = Field(default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)") description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used")
# fmt: on # fmt: on
class Config(InvocationConfig): class Config(InvocationConfig):
@ -166,7 +179,6 @@ class ControlNetInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
control=ControlField( control=ControlField(
image=self.image, image=self.image,
@ -174,10 +186,11 @@ class ControlNetInvocation(BaseInvocation):
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
), ),
) )
# TODO: move image processors to separate file (image_analysis.py
class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Base class for invocations that preprocess images for ControlNet""" """Base class for invocations that preprocess images for ControlNet"""
@ -449,6 +462,104 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
# fmt: on # fmt: on
def run_processor(self, image): def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == 'RGBA':
image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector() mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
thr_a: float = Field(default=0, description="Leres parameter `thr_a`")
thr_b: float = Field(default=0, description="Leres parameter `thr_b`")
boost: bool = Field(default=False, description="Whether to use boost mode")
detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection")
image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image")
# fmt: on
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
#res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile")
down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate")
# fmt: on
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
H, W, C = np_img.shape
H = int(float(H) / float(down_sampling_rate))
W = int(float(W) / float(down_sampling_rate))
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(np_img,
#res=self.tile_size,
down_sampling_rate=self.down_sampling_rate
)
processed_image = Image.fromarray(processed_np_image)
return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
# so using ADE20k color palette instead
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)

View File

@ -23,7 +23,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import torch_dtype
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput, from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext) InvocationConfig, InvocationContext)
@ -59,31 +59,12 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
height=latents.size()[2] * 8, height=latents.size()[2] * 8,
) )
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
#fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
SAMPLER_NAME_VALUES = Literal[ SAMPLER_NAME_VALUES = Literal[
tuple(list(SCHEDULER_MAP.keys())) tuple(list(SCHEDULER_MAP.keys()))
] ]
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
@ -105,62 +86,6 @@ def get_scheduler(
return scheduler return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
# Text to image # Text to image
class TextToLatentsInvocation(BaseInvocation): class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""
@ -287,19 +212,14 @@ class TextToLatentsInvocation(BaseInvocation):
control_height_resize = latents_shape[2] * 8 control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8 control_width_resize = latents_shape[3] * 8
if control_input is None: if control_input is None:
# print("control input is None")
control_list = None control_list = None
elif isinstance(control_input, list) and len(control_input) == 0: elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None control_list = None
elif isinstance(control_input, ControlField): elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input] control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField): elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input control_list = control_input
else: else:
# print("input control is unrecognized:", type(self.control))
control_list = None control_list = None
if (control_list is None): if (control_list is None):
control_data = None control_data = None
@ -341,12 +261,15 @@ class TextToLatentsInvocation(BaseInvocation):
# num_images_per_prompt=num_images_per_prompt, # num_images_per_prompt=num_images_per_prompt,
device=control_model.device, device=control_model.device,
dtype=control_model.dtype, dtype=control_model.dtype,
control_mode=control_info.control_mode,
) )
control_item = ControlNetData(model=control_model, control_item = ControlNetData(model=control_model,
image_tensor=control_image, image_tensor=control_image,
weight=control_info.control_weight, weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent, begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent) end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
)
control_data.append(control_item) control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData] # MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data return control_data

View File

@ -0,0 +1,134 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
import math
from typing import Literal
from pydantic import Field, validator
import torch
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
"""
Utilities
"""
def get_noise(
width: int,
height: int,
device: torch.device,
seed: int = 0,
latent_channels: int = 4,
downsampling_factor: int = 8,
use_cpu: bool = True,
perlin: float = 0.0,
):
"""Generate noise for a given image size."""
noise_device_type = "cpu" if (use_cpu or device.type == "mps") else device.type
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
generator = torch.Generator(device=noise_device_type).manual_seed(seed)
noise_tensor = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=noise_device_type,
generator=generator,
).to(device)
return noise_tensor
"""
Nodes
"""
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
# fmt: off
type: Literal["noise_output"] = "noise_output"
# Inputs
noise: LatentsField = Field(default=None, description="The output noise")
width: int = Field(description="The width of the noise in pixels")
height: int = Field(description="The height of the noise in pixels")
# fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(
ge=0,
le=SEED_MAX,
description="The seed to use",
default_factory=get_random_seed,
)
width: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting noise",
)
height: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting noise",
)
use_cpu: bool = Field(
default=True,
description="Use CPU for noise generation (for reproducible results across platforms)",
)
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)

View File

@ -133,20 +133,19 @@ class StepParamEasingInvocation(BaseInvocation):
postlist = list(num_poststeps * [self.post_end_value]) postlist = list(num_poststeps * [self.post_end_value])
if log_diagnostics: if log_diagnostics:
logger = InvokeAILogger.getLogger(name="StepParamEasing") context.services.logger.debug("start_step: " + str(start_step))
logger.debug("start_step: " + str(start_step)) context.services.logger.debug("end_step: " + str(end_step))
logger.debug("end_step: " + str(end_step)) context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
logger.debug("num_easing_steps: " + str(num_easing_steps)) context.services.logger.debug("num_presteps: " + str(num_presteps))
logger.debug("num_presteps: " + str(num_presteps)) context.services.logger.debug("num_poststeps: " + str(num_poststeps))
logger.debug("num_poststeps: " + str(num_poststeps)) context.services.logger.debug("prelist size: " + str(len(prelist)))
logger.debug("prelist size: " + str(len(prelist))) context.services.logger.debug("postlist size: " + str(len(postlist)))
logger.debug("postlist size: " + str(len(postlist))) context.services.logger.debug("prelist: " + str(prelist))
logger.debug("prelist: " + str(prelist)) context.services.logger.debug("postlist: " + str(postlist))
logger.debug("postlist: " + str(postlist))
easing_class = EASING_FUNCTIONS_MAP[self.easing] easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics: if log_diagnostics:
logger.debug("easing class: " + str(easing_class)) context.services.logger.debug("easing class: " + str(easing_class))
easing_list = list() easing_list = list()
if self.mirror: # "expected" mirroring if self.mirror: # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2 # if number of steps is even, squeeze duration down to (number_of_steps)/2
@ -156,7 +155,7 @@ class StepParamEasingInvocation(BaseInvocation):
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps/2.0)) base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: logger.debug("base easing duration: " + str(base_easing_duration)) if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value, easing_function = easing_class(start=self.start_value,
end=self.end_value, end=self.end_value,
@ -166,14 +165,14 @@ class StepParamEasingInvocation(BaseInvocation):
easing_val = easing_function.ease(step_index) easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val) base_easing_vals.append(easing_val)
if log_diagnostics: if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val)) context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps: if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals)) mirror_easing_vals = list(reversed(base_easing_vals))
else: else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1])) mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics: if log_diagnostics:
logger.debug("base easing vals: " + str(base_easing_vals)) context.services.logger.debug("base easing vals: " + str(base_easing_vals))
logger.debug("mirror easing vals: " + str(mirror_easing_vals)) context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals easing_list = base_easing_vals + mirror_easing_vals
# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@ -206,12 +205,12 @@ class StepParamEasingInvocation(BaseInvocation):
step_val = easing_function.ease(step_index) step_val = easing_function.ease(step_index)
easing_list.append(step_val) easing_list.append(step_val)
if log_diagnostics: if log_diagnostics:
logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val)) context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
if log_diagnostics: if log_diagnostics:
logger.debug("prelist size: " + str(len(prelist))) context.services.logger.debug("prelist size: " + str(len(prelist)))
logger.debug("easing_list size: " + str(len(easing_list))) context.services.logger.debug("easing_list size: " + str(len(easing_list)))
logger.debug("postlist size: " + str(len(postlist))) context.services.logger.debug("postlist size: " + str(len(postlist)))
param_list = prelist + easing_list + postlist param_list = prelist + easing_list + postlist

View File

@ -374,8 +374,10 @@ setting environment variables INVOKEAI_<setting>.
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance') tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths') root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths') autoimport_dir : Path = Field(default='autoimport/main', description='Path to a directory of models files to be imported on startup.', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Deprecated configuration option.', category='Paths') lora_dir : Path = Field(default='autoimport/lora', description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
embedding_dir : Path = Field(default='autoimport/embedding', description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
controlnet_dir : Path = Field(default='autoimport/controlnet', description='Path to a directory of ControlNet embeddings to be imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths') conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths') models_dir : Path = Field(default='models', description='Path to the models directory', category='Paths')
legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths') legacy_conf_dir : Path = Field(default='configs/stable-diffusion', description='Path to directory of legacy checkpoint config files', category='Paths')

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph

View File

@ -7,8 +7,6 @@
# Coauthor: Kevin Turner http://github.com/keturn # Coauthor: Kevin Turner http://github.com/keturn
# #
import sys import sys
print("Loading Python libraries...\n",file=sys.stderr)
import argparse import argparse
import io import io
import os import os
@ -442,6 +440,26 @@ to allow InvokeAI to download restricted styles & subjects from the "Concept Lib
scroll_exit=True, scroll_exit=True,
) )
self.nextrely += 1 self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Directories containing textual inversion, controlnet and LoRA models (<tab> autocompletes, ctrl-N advances):",
editable=False,
color="CONTROL",
)
self.autoimport_dirs = {}
for description, config_name, path in autoimport_paths(old_opts):
self.autoimport_dirs[config_name] = self.add_widget_intelligent(
npyscreen.TitleFilename,
name=description+':',
value=str(path),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=32,
scroll_exit=True
)
self.nextrely += 1
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name="== LICENSE ==", name="== LICENSE ==",
@ -505,10 +523,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
bad_fields.append( bad_fields.append(
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory." f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
) )
# if not Path(opt.embedding_dir).parent.exists():
# bad_fields.append(
# f"The embedding directory does not seem to be valid. Please check that {str(Path(opt.embedding_dir).parent)} is an existing directory."
# )
if len(bad_fields) > 0: if len(bad_fields) > 0:
message = "The following problems were detected and must be corrected:\n" message = "The following problems were detected and must be corrected:\n"
for problem in bad_fields: for problem in bad_fields:
@ -528,12 +542,15 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
"max_loaded_models", "max_loaded_models",
"xformers_enabled", "xformers_enabled",
"always_use_cpu", "always_use_cpu",
# "embedding_dir",
# "lora_dir",
# "controlnet_dir",
]: ]:
setattr(new_opts, attr, getattr(self, attr).value) setattr(new_opts, attr, getattr(self, attr).value)
for attr in self.autoimport_dirs:
directory = Path(self.autoimport_dirs[attr].value)
if directory.is_relative_to(config.root_path):
directory = directory.relative_to(config.root_path)
setattr(new_opts, attr, directory)
new_opts.hf_token = self.hf_token.value new_opts.hf_token = self.hf_token.value
new_opts.license_acceptance = self.license_acceptance.value new_opts.license_acceptance = self.license_acceptance.value
new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] new_opts.precision = PRECISION_CHOICES[self.precision.value[0]]
@ -595,22 +612,32 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
else [models[x].path or models[x].repo_id for x in installer.recommended_models()] else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
if program_opts.yes_to_all if program_opts.yes_to_all
else list(), else list(),
scan_directory=None, # scan_directory=None,
autoscan_on_startup=None, # autoscan_on_startup=None,
) )
# -------------------------------------
def autoimport_paths(config: InvokeAIAppConfig):
return [
('Checkpoints & diffusers models', 'autoimport_dir', config.root_path / config.autoimport_dir),
('LoRA/LyCORIS models', 'lora_dir', config.root_path / config.lora_dir),
('Controlnet models', 'controlnet_dir', config.root_path / config.controlnet_dir),
('Textual Inversion Embeddings', 'embedding_dir', config.root_path / config.embedding_dir),
]
# ------------------------------------- # -------------------------------------
def initialize_rootdir(root: Path, yes_to_all: bool = False): def initialize_rootdir(root: Path, yes_to_all: bool = False):
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **") logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
for name in ( for name in (
"models", "models",
"databases", "databases",
"autoimport",
"text-inversion-output", "text-inversion-output",
"text-inversion-training-data", "text-inversion-training-data",
"configs" "configs"
): ):
os.makedirs(os.path.join(root, name), exist_ok=True) os.makedirs(os.path.join(root, name), exist_ok=True)
for model_type in ModelType:
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
configs_src = Path(configs.__path__[0]) configs_src = Path(configs.__path__[0])
configs_dest = root / "configs" configs_dest = root / "configs"
@ -618,9 +645,8 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
dest = root / 'models' dest = root / 'models'
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]: for model_base in BaseModelType:
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora, for model_type in ModelType:
ModelType.ControlNet,ModelType.TextualInversion]:
path = dest / model_base.value / model_type.value path = dest / model_base.value / model_type.value
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
path = dest / 'core' path = dest / 'core'
@ -632,8 +658,6 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
} }
) )
) )
# with open(root / 'invokeai.yaml','w') as f:
# f.write('#empty invokeai.yaml initialization file')
# ------------------------------------- # -------------------------------------
def run_console_ui( def run_console_ui(
@ -680,18 +704,6 @@ def write_opts(opts: Namespace, init_file: Path):
def default_output_dir() -> Path: def default_output_dir() -> Path:
return config.root_path / "outputs" return config.root_path / "outputs"
# # -------------------------------------
# def default_embedding_dir() -> Path:
# return config.root_path / "embeddings"
# # -------------------------------------
# def default_lora_dir() -> Path:
# return config.root_path / "loras"
# # -------------------------------------
# def default_controlnet_dir() -> Path:
# return config.root_path / "controlnets"
# ------------------------------------- # -------------------------------------
def write_default_options(program_opts: Namespace, initfile: Path): def write_default_options(program_opts: Namespace, initfile: Path):
opt = default_startup_options(initfile) opt = default_startup_options(initfile)

View File

@ -70,8 +70,8 @@ class ModelInstallList:
class InstallSelections(): class InstallSelections():
install_models: List[str]= field(default_factory=list) install_models: List[str]= field(default_factory=list)
remove_models: List[str]=field(default_factory=list) remove_models: List[str]=field(default_factory=list)
scan_directory: Path = None # scan_directory: Path = None
autoscan_on_startup: bool=False # autoscan_on_startup: bool=False
@dataclass @dataclass
class ModelLoadInfo(): class ModelLoadInfo():
@ -155,8 +155,6 @@ class ModelInstall(object):
def install(self, selections: InstallSelections): def install(self, selections: InstallSelections):
job = 1 job = 1
jobs = len(selections.remove_models) + len(selections.install_models) jobs = len(selections.remove_models) + len(selections.install_models)
if selections.scan_directory:
jobs += 1
# remove requested models # remove requested models
for key in selections.remove_models: for key in selections.remove_models:
@ -171,18 +169,8 @@ class ModelInstall(object):
self.heuristic_install(path) self.heuristic_install(path)
job += 1 job += 1
# import from the scan directory, if any
if path := selections.scan_directory:
logger.info(f'Scanning and importing models from directory {path} [{job}/{jobs}]')
self.heuristic_install(path)
self.mgr.commit() self.mgr.commit()
if selections.autoscan_on_startup and Path(selections.scan_directory).is_dir():
update_autoimport_dir(selections.scan_directory)
else:
update_autoimport_dir(None)
def heuristic_install(self, def heuristic_install(self,
model_path_id_or_url: Union[str,Path], model_path_id_or_url: Union[str,Path],
models_installed: Set[Path]=None)->Set[Path]: models_installed: Set[Path]=None)->Set[Path]:
@ -228,7 +216,7 @@ class ModelInstall(object):
# the model from being probed twice in the event that it has already been probed. # the model from being probed twice in the event that it has already been probed.
def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path: def _install_path(self, path: Path, info: ModelProbeInfo=None)->Path:
try: try:
logger.info(f'Probing {path}') # logger.debug(f'Probing {path}')
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
model_name = path.stem if info.format=='checkpoint' else path.name model_name = path.stem if info.format=='checkpoint' else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
@ -237,7 +225,7 @@ class ModelInstall(object):
self.mgr.add_model(model_name = model_name, self.mgr.add_model(model_name = model_name,
base_model = info.base_type, base_model = info.base_type,
model_type = info.model_type, model_type = info.model_type,
model_attributes = attributes model_attributes = attributes,
) )
except Exception as e: except Exception as e:
logger.warning(f'{str(e)} Skipping registration.') logger.warning(f'{str(e)} Skipping registration.')
@ -309,11 +297,11 @@ class ModelInstall(object):
return location.stem return location.stem
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict: def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
# convoluted way to retrieve the description from datasets model_name = path.name if path.is_dir() else path.stem
description = f'{info.base_type.value} {info.model_type.value} model' description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
if key := self.reverse_paths.get(self.current_id): if key := self.reverse_paths.get(self.current_id):
if key in self.datasets: if key in self.datasets:
description = self.datasets[key]['description'] description = self.datasets[key].get('description') or description
rel_path = self.relative_to_root(path) rel_path = self.relative_to_root(path)
@ -395,23 +383,6 @@ class ModelInstall(object):
''' '''
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()} return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
def update_autoimport_dir(autodir: Path):
'''
Update the "autoimport_dir" option in invokeai.yaml
'''
with open('log.txt','a') as f:
print(f'autodir = {autodir}',file=f)
invokeai_config_path = config.init_file_path
conf = OmegaConf.load(invokeai_config_path)
conf.InvokeAI.Paths.autoimport_dir = str(autodir) if autodir else None
yaml = OmegaConf.to_yaml(conf)
tmpfile = invokeai_config_path.parent / "new_config.tmp"
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(yaml)
tmpfile.replace(invokeai_config_path)
# ------------------------------------- # -------------------------------------
def yes_or_no(prompt: str, default_yes=True): def yes_or_no(prompt: str, default_yes=True):
default = "y" if default_yes else "n" default = "y" if default_yes else "n"

View File

@ -168,11 +168,27 @@ structure at initialization time by scanning the models directory. The
in-memory data structure can be resynchronized by calling in-memory data structure can be resynchronized by calling
`manager.scan_models_directory()`. `manager.scan_models_directory()`.
Files and folders placed inside the `autoimport_dir` (path defined in Files and folders placed inside the `autoimport` paths (paths
`invokeai.yaml`, defaulting to `ROOTDIR/autoimport` will also be defined in `invokeai.yaml`) will also be scanned for new models at
scanned for new models at initialization time and added to initialization time and added to `models.yaml`. Files will not be
`models.yaml`. Files will not be moved from this location but moved from this location but preserved in-place. These directories
preserved in-place. are:
configuration default description
------------- ------- -----------
autoimport_dir autoimport/main main models
lora_dir autoimport/lora LoRA/LyCORIS models
embedding_dir autoimport/embedding TI embeddings
controlnet_dir autoimport/controlnet ControlNet models
In actuality, models located in any of these directories are scanned
to determine their type, so it isn't strictly necessary to organize
the different types in this way. This entry in `invokeai.yaml` will
recursively scan all subdirectories within `autoimport`, scan models
files it finds, and import them if recognized.
Paths:
autoimport_dir: autoimport
A model can be manually added using `add_model()` using the model's A model can be manually added using `add_model()` using the model's
name, base model, type and a dict of model attributes. See name, base model, type and a dict of model attributes. See
@ -208,6 +224,7 @@ checkpoint or safetensors file.
The path points to a file or directory on disk. If a relative path, The path points to a file or directory on disk. If a relative path,
the root is the InvokeAI ROOTDIR. the root is the InvokeAI ROOTDIR.
""" """
from __future__ import annotations from __future__ import annotations
@ -566,7 +583,7 @@ class ModelManager(object):
model_config = model_class.create_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
if clobber or model_key not in self.models: if model_key in self.models and not clobber:
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"') raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
old_model = self.models.pop(model_key, None) old_model = self.models.pop(model_key, None)
@ -697,49 +714,74 @@ class ModelManager(object):
if model_path.is_relative_to(self.app_config.root_path): if model_path.is_relative_to(self.app_config.root_path):
model_path = model_path.relative_to(self.app_config.root_path) model_path = model_path.relative_to(self.app_config.root_path)
model_config: ModelConfigBase = model_class.probe_config(str(model_path)) try:
self.models[model_key] = model_config model_config: ModelConfigBase = model_class.probe_config(str(model_path))
new_models_found = True self.models[model_key] = model_config
new_models_found = True
except NotImplementedError as e:
self.logger.warning(e)
imported_models = self.autoimport() imported_models = self.autoimport()
if (new_models_found or imported_models) and self.config_path: if (new_models_found or imported_models) and self.config_path:
self.commit() self.commit()
def autoimport(self): def autoimport(self)->set[Path]:
''' '''
Scan the autoimport directory (if defined) and import new models, delete defunct models. Scan the autoimport directory (if defined) and import new models, delete defunct models.
''' '''
# avoid circular import # avoid circular import
from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
model_manager = self) model_manager = self,
prediction_type_helper = ask_user_for_prediction_type,
)
installed = set() installed = set()
if not self.app_config.autoimport_dir:
return installed
autodir = self.app_config.root_path / self.app_config.autoimport_dir
if not (autodir and autodir.exists()):
return installed
known_paths = {(self.app_config.root_path / x['path']).resolve() for x in self.list_models()}
scanned_dirs = set() scanned_dirs = set()
for root, dirs, files in os.walk(autodir):
for d in dirs:
path = Path(root) / d
if path in known_paths:
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
installed.update(installer.heuristic_install(path))
scanned_dirs.add(path)
for f in files: config = self.app_config
path = Path(root) / f known_paths = {(self.app_config.root_path / x['path']) for x in self.list_models()}
if path in known_paths or path.parent in scanned_dirs:
continue for autodir in [config.autoimport_dir,
if path.suffix in {'.ckpt','.bin','.pth','.safetensors'}: config.lora_dir,
installed.update(installer.heuristic_install(path)) config.embedding_dir,
config.controlnet_dir]:
if autodir is None:
continue
self.logger.info(f'Scanning {autodir} for models to import')
autodir = self.app_config.root_path / autodir
if not autodir.exists():
continue
items_scanned = 0
new_models_found = set()
for root, dirs, files in os.walk(autodir):
items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in known_paths or path.parent in scanned_dirs:
scanned_dirs.add(path)
continue
if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]):
new_models_found.update(installer.heuristic_install(path))
scanned_dirs.add(path)
for f in files:
path = Path(root) / f
if path in known_paths or path.parent in scanned_dirs:
continue
if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}:
new_models_found.update(installer.heuristic_install(path))
self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models')
installed.update(new_models_found)
return installed return installed
def heuristic_import(self, def heuristic_import(self,

View File

@ -22,7 +22,7 @@ class ModelProbeInfo(object):
variant_type: ModelVariantType variant_type: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool upcast_attention: bool
format: Literal['diffusers','checkpoint'] format: Literal['diffusers','checkpoint', 'lycoris']
image_size: int image_size: int
class ProbeBase(object): class ProbeBase(object):
@ -75,22 +75,23 @@ class ModelProbe(object):
between V2-Base and V2-768 SD models. between V2-Base and V2-768 SD models.
''' '''
if model_path: if model_path:
format = 'diffusers' if model_path.is_dir() else 'checkpoint' format_type = 'diffusers' if model_path.is_dir() else 'checkpoint'
else: else:
format = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint'
model_info = None model_info = None
try: try:
model_type = cls.get_model_type_from_folder(model_path, model) \ model_type = cls.get_model_type_from_folder(model_path, model) \
if format == 'diffusers' \ if format_type == 'diffusers' \
else cls.get_model_type_from_checkpoint(model_path, model) else cls.get_model_type_from_checkpoint(model_path, model)
probe_class = cls.PROBES[format].get(model_type) probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class: if not probe_class:
return None return None
probe = probe_class(model_path, model, prediction_type_helper) probe = probe_class(model_path, model, prediction_type_helper)
base_type = probe.get_base_type() base_type = probe.get_base_type()
variant_type = probe.get_variant_type() variant_type = probe.get_variant_type()
prediction_type = probe.get_scheduler_prediction_type() prediction_type = probe.get_scheduler_prediction_type()
format = probe.get_format()
model_info = ModelProbeInfo( model_info = ModelProbeInfo(
model_type = model_type, model_type = model_type,
base_type = base_type, base_type = base_type,
@ -116,10 +117,10 @@ class ModelProbe(object):
if model_path.name == "learned_embeds.bin": if model_path.name == "learned_embeds.bin":
return ModelType.TextualInversion return ModelType.TextualInversion
checkpoint = checkpoint or read_checkpoint_meta(model_path, scan=True) ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
checkpoint = checkpoint.get("state_dict", checkpoint) ckpt = ckpt.get("state_dict", ckpt)
for key in checkpoint.keys(): for key in ckpt.keys():
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
@ -133,7 +134,7 @@ class ModelProbe(object):
else: else:
# diffusers-ti # diffusers-ti
if len(checkpoint) < 10 and all(isinstance(v, torch.Tensor) for v in checkpoint.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion return ModelType.TextualInversion
raise ValueError("Unable to determine model type") raise ValueError("Unable to determine model type")
@ -201,6 +202,9 @@ class ProbeBase(object):
def get_scheduler_prediction_type(self)->SchedulerPredictionType: def get_scheduler_prediction_type(self)->SchedulerPredictionType:
pass pass
def get_format(self)->str:
pass
class CheckpointProbeBase(ProbeBase): class CheckpointProbeBase(ProbeBase):
def __init__(self, def __init__(self,
checkpoint_path: Path, checkpoint_path: Path,
@ -214,6 +218,9 @@ class CheckpointProbeBase(ProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
pass pass
def get_format(self)->str:
return 'checkpoint'
def get_variant_type(self)-> ModelVariantType: def get_variant_type(self)-> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint) model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint)
if model_type != ModelType.Main: if model_type != ModelType.Main:
@ -255,7 +262,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return SchedulerPredictionType.Epsilon return SchedulerPredictionType.Epsilon
elif checkpoint["global_step"] == 110000: elif checkpoint["global_step"] == 110000:
return SchedulerPredictionType.VPrediction return SchedulerPredictionType.VPrediction
if self.checkpoint_path and self.helper: if self.checkpoint_path and self.helper \
and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed
return self.helper(self.checkpoint_path) return self.helper(self.checkpoint_path)
else: else:
return None return None
@ -266,6 +274,9 @@ class VaeCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
class LoRACheckpointProbe(CheckpointProbeBase): class LoRACheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return 'lycoris'
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
@ -285,6 +296,9 @@ class LoRACheckpointProbe(CheckpointProbeBase):
return None return None
class TextualInversionCheckpointProbe(CheckpointProbeBase): class TextualInversionCheckpointProbe(CheckpointProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
if 'string_to_token' in checkpoint: if 'string_to_token' in checkpoint:
@ -331,6 +345,9 @@ class FolderProbeBase(ProbeBase):
def get_variant_type(self)->ModelVariantType: def get_variant_type(self)->ModelVariantType:
return ModelVariantType.Normal return ModelVariantType.Normal
def get_format(self)->str:
return 'diffusers'
class PipelineFolderProbe(FolderProbeBase): class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
if self.model: if self.model:
@ -386,6 +403,9 @@ class VaeFolderProbe(FolderProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
class TextualInversionFolderProbe(FolderProbeBase): class TextualInversionFolderProbe(FolderProbeBase):
def get_format(self)->str:
return None
def get_base_type(self)->BaseModelType: def get_base_type(self)->BaseModelType:
path = self.folder_path / 'learned_embeds.bin' path = self.folder_path / 'learned_embeds.bin'
if not path.exists(): if not path.exists():

View File

@ -397,7 +397,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
checkpoint = safetensors.torch.load_file(path, device="cpu") checkpoint = safetensors.torch.load_file(path, device="cpu")
else: else:
if scan: if scan:
scan_result = scan_file_path(checkpoint) scan_result = scan_file_path(path)
if scan_result.infected_files != 0: if scan_result.infected_files != 0:
raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.") raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta")) checkpoint = torch.load(path, map_location=torch.device("meta"))

View File

@ -69,7 +69,7 @@ class StableDiffusion1Model(DiffusersModel):
in_channels = unet_config['in_channels'] in_channels = unet_config['in_channels']
else: else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
else: else:
raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}") raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
@ -259,8 +259,8 @@ def _convert_ckpt_and_cache(
""" """
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_dir / model_config.path weights = app_config.root_path / model_config.path
config_file = app_config.root_dir / model_config.config config_file = app_config.root_path / model_config.config
output_path = Path(output_path) output_path = Path(output_path)
# return cached version if it exists # return cached version if it exists

View File

@ -215,10 +215,12 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
@dataclass @dataclass
class ControlNetData: class ControlNetData:
model: ControlNetModel = Field(default=None) model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor= Field(default=None) image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]]= Field(default=1.0) weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
@dataclass(frozen=True) @dataclass(frozen=True)
class ConditioningData: class ConditioningData:
@ -599,48 +601,68 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# TODO: should this scaling happen here or inside self._unet_forward? # TODO: should this scaling happen here or inside self._unet_forward?
# i.e. before or after passing it to InvokeAIDiffuserComponent # i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep) unet_latent_input = self.scheduler.scale_model_input(latents, timestep)
# default is no controlnet, so set controlnet processing output to None # default is no controlnet, so set controlnet processing output to None
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
if control_data is not None: if control_data is not None:
# FIXME: make sure guidance_scale < 1.0 is handled correctly if doing per-step guidance setting
# if conditioning_data.guidance_scale > 1.0:
if conditioning_data.guidance_scale is not None:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
latent_control_input = torch.cat([latent_model_input] * 2)
else:
latent_control_input = latent_model_input
# control_data should be type List[ControlNetData] # control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list) # this loop covers both ControlNet (one ControlNetData in list)
# and MultiControlNet (multiple ControlNetData in list) # and MultiControlNet (multiple ControlNetData in list)
for i, control_datum in enumerate(control_data): for i, control_datum in enumerate(control_data):
# print("controlnet", i, "==>", type(control_datum)) control_mode = control_datum.control_mode
# soft_injection and cfg_injection are the two ControlNet control_mode booleans
# that are combined at higher level to make control_mode enum
# soft_injection determines whether to do per-layer re-weighting adjustment (if True)
# or default weighting (if False)
soft_injection = (control_mode == "more_prompt" or control_mode == "more_control")
# cfg_injection = determines whether to apply ControlNet to only the conditional (if True)
# or the default both conditional and unconditional (if False)
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) first_control_step = math.floor(control_datum.begin_step_percent * total_step_count)
last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) last_control_step = math.ceil(control_datum.end_step_percent * total_step_count)
# only apply controlnet if current step is within the controlnet's begin/end step range # only apply controlnet if current step is within the controlnet's begin/end step range
if step_index >= first_control_step and step_index <= last_control_step: if step_index >= first_control_step and step_index <= last_control_step:
# print("running controlnet", i, "for step", step_index)
if cfg_injection:
control_latent_input = unet_latent_input
else:
# expand the latents input to control model if doing classifier free guidance
# (which I think for now is always true, there is conditional elsewhere that stops execution if
# classifier_free_guidance is <= 1.0 ?)
control_latent_input = torch.cat([unet_latent_input] * 2)
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings])
else:
encoder_hidden_states = torch.cat([conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings])
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
controlnet_weight = control_datum.weight[step_index] controlnet_weight = control_datum.weight[step_index]
else: else:
# if controlnet has a single weight, use it for all steps # if controlnet has a single weight, use it for all steps
controlnet_weight = control_datum.weight controlnet_weight = control_datum.weight
# controlnet(s) inference
down_samples, mid_sample = control_datum.model( down_samples, mid_sample = control_datum.model(
sample=latent_control_input, sample=control_latent_input,
timestep=timestep, timestep=timestep,
encoder_hidden_states=torch.cat([conditioning_data.unconditioned_embeddings, encoder_hidden_states=encoder_hidden_states,
conditioning_data.text_embeddings]),
controlnet_cond=control_datum.image_tensor, controlnet_cond=control_datum.image_tensor,
conditioning_scale=controlnet_weight, conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale
# cross_attention_kwargs, guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
guess_mode=False,
return_dict=False, return_dict=False,
) )
if cfg_injection:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
if down_block_res_samples is None and mid_block_res_sample is None: if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else: else:
@ -653,11 +675,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# predict the noise residual # predict the noise residual
noise_pred = self.invokeai_diffuser.do_diffusion_step( noise_pred = self.invokeai_diffuser.do_diffusion_step(
latent_model_input, x=unet_latent_input,
t, sigma=t,
conditioning_data.unconditioned_embeddings, unconditioning=conditioning_data.unconditioned_embeddings,
conditioning_data.text_embeddings, conditioning=conditioning_data.text_embeddings,
conditioning_data.guidance_scale, unconditional_guidance_scale=conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
down_block_additional_residuals=down_block_res_samples, # from controlnet(s) down_block_additional_residuals=down_block_res_samples, # from controlnet(s)
@ -962,6 +984,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
device="cuda", device="cuda",
dtype=torch.float16, dtype=torch.float16,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
control_mode="balanced"
): ):
if not isinstance(image, torch.Tensor): if not isinstance(image, torch.Tensor):
@ -992,6 +1015,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
repeat_by = num_images_per_prompt repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0) image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype) image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance: cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
image = torch.cat([image] * 2) image = torch.cat([image] * 2)
return image return image

View File

@ -131,7 +131,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
window_width=window_width, window_width=window_width,
exclude = self.starter_models exclude = self.starter_models
) )
self.pipeline_models['autoload_pending'] = True # self.pipeline_models['autoload_pending'] = True
bottom_of_table = max(bottom_of_table,self.nextrely) bottom_of_table = max(bottom_of_table,self.nextrely)
self.nextrely = top_of_table self.nextrely = top_of_table
@ -316,31 +316,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
**kwargs, **kwargs,
) )
label = "Directory to scan for models to automatically import (<tab> autocompletes):"
self.nextrely += 1
widgets.update(
autoload_directory = self.add_widget_intelligent(
FileBox,
max_height=3,
name=label,
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else None,
select_dir=True,
must_exist=True,
use_two_lines=False,
labelColor="DANGER",
begin_entry_at=len(label)+1,
scroll_exit=True,
)
)
widgets.update(
autoscan_on_startup = self.add_widget_intelligent(
npyscreen.Checkbox,
name="Scan and import from this directory each time InvokeAI starts",
value=config.autoimport_dir is not None,
relx=4,
scroll_exit=True,
)
)
return widgets return widgets
def resize(self): def resize(self):
@ -501,8 +476,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# rebuild the form, saving and restoring some of the fields that need to be preserved. # rebuild the form, saving and restoring some of the fields that need to be preserved.
saved_messages = self.monitor.entry_widget.values saved_messages = self.monitor.entry_widget.values
autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value) # autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value)
autoscan = self.pipeline_models['autoscan_on_startup'].value # autoscan = self.pipeline_models['autoscan_on_startup'].value
app.main_form = app.addForm( app.main_form = app.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage, "MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage,
@ -511,8 +486,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
app.main_form.monitor.entry_widget.values = saved_messages app.main_form.monitor.entry_widget.values = saved_messages
app.main_form.monitor.entry_widget.buffer([''],scroll_end=True) app.main_form.monitor.entry_widget.buffer([''],scroll_end=True)
app.main_form.pipeline_models['autoload_directory'].value = autoload_dir # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir
app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan
def marshall_arguments(self): def marshall_arguments(self):
""" """
@ -546,17 +521,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
selections.install_models.extend(downloads.value.split()) selections.install_models.extend(downloads.value.split())
# load directory and whether to scan on startup # load directory and whether to scan on startup
if self.parentApp.autoload_pending: # if self.parentApp.autoload_pending:
selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value) # selections.scan_directory = str(config.root_path / self.pipeline_models['autoload_directory'].value)
self.parentApp.autoload_pending = False # self.parentApp.autoload_pending = False
selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value # selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value
class AddModelApplication(npyscreen.NPSAppManaged): class AddModelApplication(npyscreen.NPSAppManaged):
def __init__(self,opt): def __init__(self,opt):
super().__init__() super().__init__()
self.program_opts = opt self.program_opts = opt
self.user_cancelled = False self.user_cancelled = False
self.autoload_pending = True # self.autoload_pending = True
self.install_selections = InstallSelections() self.install_selections = InstallSelections()
def onStart(self): def onStart(self):
@ -578,20 +553,20 @@ class StderrToMessage():
# -------------------------------------------------------- # --------------------------------------------------------
def ask_user_for_prediction_type(model_path: Path, def ask_user_for_prediction_type(model_path: Path,
tui_conn: Connection=None tui_conn: Connection=None
)->Path: )->SchedulerPredictionType:
if tui_conn: if tui_conn:
logger.debug('Waiting for user response...') logger.debug('Waiting for user response...')
return _ask_user_for_pt_tui(model_path, tui_conn) return _ask_user_for_pt_tui(model_path, tui_conn)
else: else:
return _ask_user_for_pt_cmdline(model_path) return _ask_user_for_pt_cmdline(model_path)
def _ask_user_for_pt_cmdline(model_path): def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType:
choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None]
print( print(
f""" f"""
Please select the type of the V2 checkpoint named {model_path.name}: Please select the type of the V2 checkpoint named {model_path.name}:
[1] A Stable Diffusion v2.x base model (512 pixels; there should be no 'parameterization:' line in its yaml file) [1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
[2] A Stable Diffusion v2.x v-predictive model (768 pixels; look for a 'parameterization: "v"' line in its yaml file) [2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
[3] Skip this model and come back later. [3] Skip this model and come back later.
""" """
) )
@ -608,7 +583,7 @@ Please select the type of the V2 checkpoint named {model_path.name}:
return return
return choice return choice
def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->Path: def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType:
try: try:
tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8'))
# note that we don't do any status checking here # note that we don't do any status checking here
@ -810,13 +785,15 @@ def main():
logger.error( logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again" "Insufficient vertical space for the interface. Please make your window taller and try again"
) )
elif str(e).startswith("addwstr"): input('Press any key to continue...')
except Exception as e:
if str(e).startswith("addwstr"):
logger.error( logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again." "Insufficient horizontal space for the interface. Please make your window wider and try again."
) )
except Exception as e: else:
print(f'An exception has occurred: {str(e)} Details:') print(f'An exception has occurred: {str(e)} Details:')
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
input('Press any key to continue...') input('Press any key to continue...')

View File

@ -42,6 +42,18 @@ def set_terminal_size(columns: int, lines: int, launch_command: str=None):
elif OS in ["Darwin", "Linux"]: elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width,height) _set_terminal_size_unix(width,height)
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print('\033[1mThis window is too narrow for the user interface. Please make it wider.\033[0m')
pause = True
if ts.lines < lines:
print('\033[1mThis window is too short for the user interface. Please make it taller.\033[0m')
pause = True
if pause:
input('Press any key to continue..')
def _set_terminal_size_powershell(width: int, height: int): def _set_terminal_size_powershell(width: int, height: int):
script=f''' script=f'''
$pshost = get-host $pshost = get-host

View File

@ -0,0 +1,14 @@
import react from '@vitejs/plugin-react-swc';
import { visualizer } from 'rollup-plugin-visualizer';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import { nodePolyfills } from 'vite-plugin-node-polyfills';
export const commonPlugins: UserConfig['plugins'] = [
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
nodePolyfills(),
];

View File

@ -1,17 +1,9 @@
import react from '@vitejs/plugin-react-swc'; import { UserConfig } from 'vite';
import { visualizer } from 'rollup-plugin-visualizer'; import { commonPlugins } from './common';
import { PluginOption, UserConfig } from 'vite';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
export const appConfig: UserConfig = { export const appConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [...commonPlugins],
react(),
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
],
build: { build: {
chunkSizeWarningLimit: 1500, chunkSizeWarningLimit: 1500,
}, },

View File

@ -1,19 +1,13 @@
import react from '@vitejs/plugin-react-swc';
import path from 'path'; import path from 'path';
import { visualizer } from 'rollup-plugin-visualizer'; import { UserConfig } from 'vite';
import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts'; import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js'; import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
import { commonPlugins } from './common';
export const packageConfig: UserConfig = { export const packageConfig: UserConfig = {
base: './', base: './',
plugins: [ plugins: [
react(), ...commonPlugins,
eslint(),
tsconfigPaths(),
visualizer() as unknown as PluginOption,
dts({ dts({
insertTypesEntry: true, insertTypesEntry: true,
}), }),

View File

@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"", "dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"", "dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build", "build": "yarn run lint && vite build",
"typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/schema.d.ts -t", "typegen": "npx openapi-typescript http://localhost:9090/openapi.json --output src/services/api/schema.d.ts -t",
"preview": "vite preview", "preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx", "lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .", "lint:eslint": "eslint --max-warnings=0 .",
@ -53,36 +53,38 @@
] ]
}, },
"dependencies": { "dependencies": {
"@apidevtools/swagger-parser": "^10.1.0",
"@chakra-ui/anatomy": "^2.1.1", "@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/icons": "^2.0.19", "@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.6.0", "@chakra-ui/react": "^2.7.1",
"@chakra-ui/styled-system": "^2.9.0", "@chakra-ui/styled-system": "^2.9.1",
"@chakra-ui/theme-tools": "^2.0.16", "@chakra-ui/theme-tools": "^2.0.18",
"@dagrejs/graphlib": "^2.1.12", "@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8", "@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1", "@dnd-kit/modifiers": "^6.0.1",
"@emotion/react": "^11.11.1", "@emotion/react": "^11.11.1",
"@emotion/styled": "^11.10.6", "@emotion/styled": "^11.11.0",
"@floating-ui/react-dom": "^2.0.0", "@floating-ui/react-dom": "^2.0.1",
"@fontsource/inter": "^4.5.15", "@fontsource-variable/inter": "^5.0.3",
"@mantine/core": "^6.0.13", "@fontsource/inter": "^5.0.3",
"@mantine/hooks": "^6.0.13", "@mantine/core": "^6.0.14",
"@mantine/hooks": "^6.0.14",
"@reduxjs/toolkit": "^1.9.5", "@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5", "@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5", "chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3", "dateformat": "^5.0.3",
"downshift": "^7.6.0", "downshift": "^7.6.0",
"formik": "^2.2.9", "formik": "^2.4.2",
"framer-motion": "^10.12.4", "framer-motion": "^10.12.17",
"fuse.js": "^6.6.2", "fuse.js": "^6.6.2",
"i18next": "^22.4.15", "i18next": "^23.2.3",
"i18next-browser-languagedetector": "^7.0.1", "i18next-browser-languagedetector": "^7.0.2",
"i18next-http-backend": "^2.2.0", "i18next-http-backend": "^2.2.1",
"konva": "^9.0.1", "konva": "^9.2.0",
"lodash-es": "^4.17.21", "lodash-es": "^4.17.21",
"nanostores": "^0.9.2", "nanostores": "^0.9.2",
"openapi-fetch": "^0.4.0", "openapi-fetch": "^0.4.0",
"overlayscrollbars": "^2.1.1", "overlayscrollbars": "^2.2.0",
"overlayscrollbars-react": "^0.5.0", "overlayscrollbars-react": "^0.5.0",
"patch-package": "^7.0.0", "patch-package": "^7.0.0",
"query-string": "^8.1.0", "query-string": "^8.1.0",
@ -92,21 +94,21 @@
"react-dom": "^18.2.0", "react-dom": "^18.2.0",
"react-dropzone": "^14.2.3", "react-dropzone": "^14.2.3",
"react-hotkeys-hook": "4.4.0", "react-hotkeys-hook": "4.4.0",
"react-i18next": "^12.2.2", "react-i18next": "^13.0.1",
"react-icons": "^4.9.0", "react-icons": "^4.10.1",
"react-konva": "^18.2.7", "react-konva": "^18.2.10",
"react-redux": "^8.0.5", "react-redux": "^8.1.1",
"react-resizable-panels": "^0.0.42", "react-resizable-panels": "^0.0.52",
"react-use": "^17.4.0", "react-use": "^17.4.0",
"react-virtuoso": "^4.3.5", "react-virtuoso": "^4.3.11",
"react-zoom-pan-pinch": "^3.0.7", "react-zoom-pan-pinch": "^3.0.8",
"reactflow": "^11.7.0", "reactflow": "^11.7.4",
"redux-dynamic-middlewares": "^2.2.0", "redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^3.3.1", "redux-remember": "^3.3.1",
"roarr": "^7.15.0", "roarr": "^7.15.0",
"serialize-error": "^11.0.0", "serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0", "socket.io-client": "^4.7.0",
"use-image": "^1.1.0", "use-image": "^1.1.1",
"uuid": "^9.0.0", "uuid": "^9.0.0",
"zod": "^3.21.4" "zod": "^3.21.4"
}, },
@ -117,22 +119,22 @@
"ts-toolbelt": "^9.6.0" "ts-toolbelt": "^9.6.0"
}, },
"devDependencies": { "devDependencies": {
"@chakra-ui/cli": "^2.4.0", "@chakra-ui/cli": "^2.4.1",
"@types/dateformat": "^5.0.0", "@types/dateformat": "^5.0.0",
"@types/lodash-es": "^4.14.194", "@types/lodash-es": "^4.14.194",
"@types/node": "^18.16.2", "@types/node": "^20.3.1",
"@types/react": "^18.2.0", "@types/react": "^18.2.14",
"@types/react-dom": "^18.2.1", "@types/react-dom": "^18.2.6",
"@types/react-redux": "^7.1.25", "@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5", "@types/react-transition-group": "^4.4.6",
"@types/uuid": "^9.0.0", "@types/uuid": "^9.0.2",
"@typescript-eslint/eslint-plugin": "^5.59.1", "@typescript-eslint/eslint-plugin": "^5.60.0",
"@typescript-eslint/parser": "^5.59.1", "@typescript-eslint/parser": "^5.60.0",
"@vitejs/plugin-react-swc": "^3.3.0", "@vitejs/plugin-react-swc": "^3.3.2",
"axios": "^1.4.0", "axios": "^1.4.0",
"babel-plugin-transform-imports": "^2.0.0", "babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^8.0.1", "concurrently": "^8.2.0",
"eslint": "^8.39.0", "eslint": "^8.43.0",
"eslint-config-prettier": "^8.8.0", "eslint-config-prettier": "^8.8.0",
"eslint-plugin-prettier": "^4.2.1", "eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2", "eslint-plugin-react": "^7.32.2",
@ -140,19 +142,20 @@
"form-data": "^4.0.0", "form-data": "^4.0.0",
"husky": "^8.0.3", "husky": "^8.0.3",
"lint-staged": "^13.2.2", "lint-staged": "^13.2.2",
"madge": "^6.0.0", "madge": "^6.1.0",
"openapi-types": "^12.1.0", "openapi-types": "^12.1.3",
"openapi-typescript": "^6.2.8", "openapi-typescript": "^6.2.8",
"openapi-typescript-codegen": "^0.24.0", "openapi-typescript-codegen": "^0.24.0",
"postinstall-postinstall": "^2.1.0", "postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.8", "prettier": "^2.8.8",
"rollup-plugin-visualizer": "^5.9.0", "rollup-plugin-visualizer": "^5.9.2",
"terser": "^5.17.1", "terser": "^5.18.1",
"ts-toolbelt": "^9.6.0", "ts-toolbelt": "^9.6.0",
"vite": "^4.3.3", "vite": "^4.3.9",
"vite-plugin-css-injected-by-js": "^3.1.1", "vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0", "vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1", "vite-plugin-eslint": "^1.8.1",
"vite-plugin-node-polyfills": "^0.9.0",
"vite-tsconfig-paths": "^4.2.0", "vite-tsconfig-paths": "^4.2.0",
"yarn": "^1.22.19" "yarn": "^1.22.19"
} }

View File

@ -1,14 +0,0 @@
diff --git a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
index 937cf0d..7dcc0c0 100644
--- a/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
+++ b/node_modules/@chakra-ui/cli/dist/scripts/read-theme-file.worker.js
@@ -50,7 +50,8 @@ async function readTheme(themeFilePath) {
project: tsConfig.configFileAbsolutePath,
compilerOptions: {
module: "CommonJS",
- esModuleInterop: true
+ esModuleInterop: true,
+ jsx: 'react'
},
transpileOnly: true,
swc: true

View File

@ -524,7 +524,8 @@
"initialImage": "Initial Image", "initialImage": "Initial Image",
"showOptionsPanel": "Show Options Panel", "showOptionsPanel": "Show Options Panel",
"hidePreview": "Hide Preview", "hidePreview": "Hide Preview",
"showPreview": "Show Preview" "showPreview": "Show Preview",
"controlNetControlMode": "Control Mode"
}, },
"settings": { "settings": {
"models": "Models", "models": "Models",

View File

@ -48,7 +48,7 @@ const App = ({
const isApplicationReady = useIsApplicationReady(); const isApplicationReady = useIsApplicationReady();
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const { data: controlnetModels } = useListModelsQuery({ const { data: controlnetModels } = useListModelsQuery({
model_type: 'controlnet', model_type: 'controlnet',

View File

@ -14,7 +14,7 @@ import { invokeAIThemeColors } from 'theme/colors/invokeAI';
import { lightThemeColors } from 'theme/colors/lightTheme'; import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue'; import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/variable.css'; import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core'; import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme'; import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css'; import 'overlayscrollbars/overlayscrollbars.css';

View File

@ -15,7 +15,7 @@ import { ImageDTO } from 'services/api/types';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
@ -30,7 +30,7 @@ export const selectImageUsage = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
(state: RootState, image_name?: string) => image_name, (state: RootState, image_name?: string) => image_name,
], ],

View File

@ -1,6 +1,7 @@
import { AnyAction } from '@reduxjs/toolkit'; import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { Graph } from 'services/api/types'; import { Graph } from 'services/api/types';
export const actionSanitizer = <A extends AnyAction>(action: A): A => { export const actionSanitizer = <A extends AnyAction>(action: A): A => {
@ -8,17 +9,6 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (action.payload.nodes) { if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {}; const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return { return {
...action, ...action,
payload: { ...action.payload, nodes: sanitizedNodes }, payload: { ...action.payload, nodes: sanitizedNodes },
@ -26,5 +16,19 @@ export const actionSanitizer = <A extends AnyAction>(action: A): A => {
} }
} }
if (receivedOpenAPISchema.fulfilled.match(action)) {
return {
...action,
payload: '<OpenAPI schema omitted>',
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
return action; return action;
}; };

View File

@ -82,6 +82,7 @@ import {
addImageRemovedFromBoardFulfilledListener, addImageRemovedFromBoardFulfilledListener,
addImageRemovedFromBoardRejectedListener, addImageRemovedFromBoardRejectedListener,
} from './listeners/imageRemovedFromBoard'; } from './listeners/imageRemovedFromBoard';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -205,3 +206,6 @@ addImageAddedToBoardRejectedListener();
addImageRemovedFromBoardFulfilledListener(); addImageRemovedFromBoardFulfilledListener();
addImageRemovedFromBoardRejectedListener(); addImageRemovedFromBoardRejectedListener();
addBoardIdSelectedListener(); addBoardIdSelectedListener();
// Node schemas
addReceivedOpenAPISchemaListener();

View File

@ -0,0 +1,35 @@
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { parseSchema } from 'features/nodes/util/parseSchema';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { size } from 'lodash-es';
const schemaLog = log.child({ namespace: 'schema' });
export const addReceivedOpenAPISchemaListener = () => {
startAppListening({
actionCreator: receivedOpenAPISchema.fulfilled,
effect: (action, { dispatch, getState }) => {
const schemaJSON = action.payload;
schemaLog.info({ data: { schemaJSON } }, 'Dereferenced OpenAPI schema');
const nodeTemplates = parseSchema(schemaJSON);
schemaLog.info(
{ data: { nodeTemplates } },
`Built ${size(nodeTemplates)} node templates`
);
dispatch(nodeTemplatesBuilt(nodeTemplates));
},
});
startAppListening({
actionCreator: receivedOpenAPISchema.rejected,
effect: (action, { dispatch, getState }) => {
schemaLog.error('Problem dereferencing OpenAPI Schema');
},
});
};

View File

@ -3,7 +3,7 @@ import { startAppListening } from '..';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { nodesSelecter } from 'features/nodes/store/nodesSlice'; import { nodesSelector } from 'features/nodes/store/nodesSlice';
import { controlNetSelector } from 'features/controlNet/store/controlNetSlice'; import { controlNetSelector } from 'features/controlNet/store/controlNetSlice';
import { forEach, uniqBy } from 'lodash-es'; import { forEach, uniqBy } from 'lodash-es';
import { imageUrlsReceived } from 'services/api/thunks/image'; import { imageUrlsReceived } from 'services/api/thunks/image';
@ -16,7 +16,7 @@ const selectAllUsedImages = createSelector(
[ [
generationSelector, generationSelector,
canvasSelector, canvasSelector,
nodesSelecter, nodesSelector,
controlNetSelector, controlNetSelector,
selectImagesEntities, selectImagesEntities,
], ],

View File

@ -22,6 +22,7 @@ import boardsReducer from 'features/gallery/store/boardSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import uiReducer from 'features/ui/store/uiSlice'; import uiReducer from 'features/ui/store/uiSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice';
import { listenerMiddleware } from './middleware/listenerMiddleware'; import { listenerMiddleware } from './middleware/listenerMiddleware';
@ -48,6 +49,7 @@ const allReducers = {
controlNet: controlNetReducer, controlNet: controlNetReducer,
boards: boardsReducer, boards: boardsReducer,
// session: sessionReducer, // session: sessionReducer,
dynamicPrompts: dynamicPromptsReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -65,6 +67,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'system', 'system',
'ui', 'ui',
'controlNet', 'controlNet',
'dynamicPrompts',
// 'boards', // 'boards',
// 'hotkeys', // 'hotkeys',
// 'config', // 'config',
@ -100,3 +103,4 @@ export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>; export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch; export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;

View File

@ -171,6 +171,14 @@ export type AppConfig = {
fineStep: number; fineStep: number;
coarseStep: number; coarseStep: number;
}; };
dynamicPrompts: {
maxPrompts: {
initial: number;
min: number;
sliderMax: number;
inputMax: number;
};
};
}; };
}; };

View File

@ -27,7 +27,6 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
borderWidth: '2px', borderWidth: '2px',
borderColor: 'var(--invokeai-colors-base-800)', borderColor: 'var(--invokeai-colors-base-800)',
color: 'var(--invokeai-colors-base-100)', color: 'var(--invokeai-colors-base-100)',
padding: 10,
paddingRight: 24, paddingRight: 24,
fontWeight: 600, fontWeight: 600,
'&:hover': { borderColor: 'var(--invokeai-colors-base-700)' }, '&:hover': { borderColor: 'var(--invokeai-colors-base-700)' },

View File

@ -34,6 +34,10 @@ const IAIMantineSelect = (props: IAISelectProps) => {
'&:focus': { '&:focus': {
borderColor: 'var(--invokeai-colors-accent-600)', borderColor: 'var(--invokeai-colors-accent-600)',
}, },
'&:disabled': {
backgroundColor: 'var(--invokeai-colors-base-700)',
color: 'var(--invokeai-colors-base-400)',
},
}, },
dropdown: { dropdown: {
backgroundColor: 'var(--invokeai-colors-base-800)', backgroundColor: 'var(--invokeai-colors-base-800)',
@ -64,7 +68,7 @@ const IAIMantineSelect = (props: IAISelectProps) => {
}, },
}, },
rightSection: { rightSection: {
width: 24, width: 32,
}, },
})} })}
{...rest} {...rest}

View File

@ -41,7 +41,15 @@ const IAISwitch = (props: Props) => {
{...formControlProps} {...formControlProps}
> >
{label && ( {label && (
<FormLabel my={1} flexGrow={1} {...formLabelProps}> <FormLabel
my={1}
flexGrow={1}
sx={{
cursor: isDisabled ? 'not-allowed' : 'pointer',
...formLabelProps?.sx,
}}
{...formLabelProps}
>
{label} {label}
</FormLabel> </FormLabel>
)} )}

View File

@ -9,10 +9,12 @@ type IAICanvasImageProps = {
}; };
const IAICanvasImage = (props: IAICanvasImageProps) => { const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage; const { width, height, x, y, imageName } = props.canvasImage;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken
);
const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous'); const [image] = useImage(imageDTO?.image_url ?? '', 'anonymous');
if (!imageDTO) { if (isError) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />; return <Rect x={x} y={y} width={width} height={height} fill="red" />;
} }

View File

@ -1,26 +1,27 @@
import { Box, ChakraProps, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import { import {
ControlNetConfig, ControlNetConfig,
controlNetAdded, controlNetAdded,
controlNetRemoved, controlNetRemoved,
controlNetToggled, controlNetToggled,
} from '../store/controlNetSlice'; } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import ParamControlNetModel from './parameters/ParamControlNetModel'; import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight'; import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { Flex, Box, ChakraProps } from '@chakra-ui/react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ControlNetImagePreview from './ControlNetImagePreview';
import IAIIconButton from 'common/components/IAIIconButton';
import { v4 as uuidv4 } from 'uuid';
import { useToggle } from 'react-use';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import IAISwitch from 'common/components/IAISwitch';
import { ChevronUpIcon } from '@chakra-ui/icons'; import { ChevronUpIcon } from '@chakra-ui/icons';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig'; import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 }; const expandedControlImageSx: ChakraProps['sx'] = { maxH: 96 };
@ -36,6 +37,7 @@ const ControlNet = (props: ControlNetProps) => {
weight, weight,
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode,
controlImage, controlImage,
processedControlImage, processedControlImage,
processorNode, processorNode,
@ -137,48 +139,54 @@ const ControlNet = (props: ControlNetProps) => {
</Flex> </Flex>
{isEnabled && ( {isEnabled && (
<> <>
<Flex sx={{ gap: 4, w: 'full' }}> <Flex sx={{ w: 'full', flexDirection: 'column' }}>
<Flex <Flex sx={{ gap: 4, w: 'full' }}>
sx={{
flexDir: 'column',
gap: 2,
w: 'full',
h: isExpanded ? 28 : 24,
paddingInlineStart: 1,
paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}}
>
<ParamControlNetWeight
controlNetId={controlNetId}
weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/>
</Flex>
{!isExpanded && (
<Flex <Flex
sx={{ sx={{
alignItems: 'center', flexDir: 'column',
justifyContent: 'center', gap: 3,
h: 24, w: 'full',
w: 24, paddingInlineStart: 1,
aspectRatio: '1/1', paddingInlineEnd: isExpanded ? 1 : 0,
pb: 2,
justifyContent: 'space-between',
}} }}
> >
<ControlNetImagePreview <ParamControlNetWeight
controlNet={props.controlNet} controlNetId={controlNetId}
height={24} weight={weight}
mini={!isExpanded}
/>
<ParamControlNetBeginEnd
controlNetId={controlNetId}
beginStepPct={beginStepPct}
endStepPct={endStepPct}
mini={!isExpanded}
/> />
</Flex> </Flex>
)} {!isExpanded && (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
h: 24,
w: 24,
aspectRatio: '1/1',
}}
>
<ControlNetImagePreview
controlNet={props.controlNet}
height={24}
/>
</Flex>
)}
</Flex>
<ParamControlNetControlMode
controlNetId={controlNetId}
controlMode={controlMode}
/>
</Flex> </Flex>
{isExpanded && ( {isExpanded && (
<> <>
<Box mt={2}> <Box mt={2}>

View File

@ -0,0 +1,45 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import {
ControlModes,
controlNetControlModeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
type ParamControlNetControlModeProps = {
controlNetId: string;
controlMode: string;
};
const CONTROL_MODE_DATA = [
{ label: 'Balanced', value: 'balanced' },
{ label: 'Prompt', value: 'more_prompt' },
{ label: 'Control', value: 'more_control' },
{ label: 'Mega Control', value: 'unbalanced' },
];
export default function ParamControlNetControlMode(
props: ParamControlNetControlModeProps
) {
const { controlNetId, controlMode = false } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleControlModeChange = useCallback(
(controlMode: ControlModes) => {
dispatch(controlNetControlModeChanged({ controlNetId, controlMode }));
},
[controlNetId, dispatch]
);
return (
<IAIMantineSelect
label={t('parameters.controlNetControlMode')}
data={CONTROL_MODE_DATA}
value={String(controlMode)}
onChange={handleControlModeChange}
/>
);
}

View File

@ -1,6 +1,5 @@
import { import {
ControlNetProcessorType, ControlNetProcessorType,
RequiredCannyImageProcessorInvocation,
RequiredControlNetProcessorNode, RequiredControlNetProcessorNode,
} from './types'; } from './types';
@ -23,7 +22,7 @@ type ControlNetProcessorsDict = Record<
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: { none: {
type: 'none', type: 'none',
label: 'none', label: 'none',
@ -174,6 +173,8 @@ export const CONTROLNET_PROCESSORS = {
}, },
}; };
type ControlNetModelsDict = Record<string, ControlNetModel>;
type ControlNetModel = { type ControlNetModel = {
type: string; type: string;
label: string; label: string;
@ -181,7 +182,7 @@ type ControlNetModel = {
defaultProcessor?: ControlNetProcessorType; defaultProcessor?: ControlNetProcessorType;
}; };
export const CONTROLNET_MODELS = { export const CONTROLNET_MODELS: ControlNetModelsDict = {
'lllyasviel/control_v11p_sd15_canny': { 'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny', type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny', label: 'Canny',
@ -190,6 +191,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_inpaint': { 'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint', type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint', label: 'Inpaint',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_mlsd': { 'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd', type: 'lllyasviel/control_v11p_sd15_mlsd',
@ -209,6 +211,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_seg': { 'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg', type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segmentation', label: 'Segmentation',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_lineart': { 'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart', type: 'lllyasviel/control_v11p_sd15_lineart',
@ -223,6 +226,7 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11p_sd15_scribble': { 'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble', type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble', label: 'Scribble',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11p_sd15_softedge': { 'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge', type: 'lllyasviel/control_v11p_sd15_softedge',
@ -242,10 +246,12 @@ export const CONTROLNET_MODELS = {
'lllyasviel/control_v11f1e_sd15_tile': { 'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile', type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)', label: 'Tile (experimental)',
defaultProcessor: 'none',
}, },
'lllyasviel/control_v11e_sd15_ip2p': { 'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p', type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)', label: 'Pix2Pix (experimental)',
defaultProcessor: 'none',
}, },
'CrucibleAI/ControlNetMediaPipeFace': { 'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace', type: 'CrucibleAI/ControlNetMediaPipeFace',

View File

@ -18,12 +18,19 @@ import { forEach } from 'lodash-es';
import { isAnySessionRejected } from 'services/api/thunks/session'; import { isAnySessionRejected } from 'services/api/thunks/session';
import { appSocketInvocationError } from 'services/events/actions'; import { appSocketInvocationError } from 'services/events/actions';
export type ControlModes =
| 'balanced'
| 'more_prompt'
| 'more_control'
| 'unbalanced';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
controlMode: 'balanced',
controlImage: null, controlImage: null,
processedControlImage: null, processedControlImage: null,
processorType: 'canny_image_processor', processorType: 'canny_image_processor',
@ -39,6 +46,7 @@ export type ControlNetConfig = {
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
controlMode: ControlModes;
controlImage: string | null; controlImage: string | null;
processedControlImage: string | null; processedControlImage: string | null;
processorType: ControlNetProcessorType; processorType: ControlNetProcessorType;
@ -181,6 +189,13 @@ export const controlNetSlice = createSlice({
const { controlNetId, endStepPct } = action.payload; const { controlNetId, endStepPct } = action.payload;
state.controlNets[controlNetId].endStepPct = endStepPct; state.controlNets[controlNetId].endStepPct = endStepPct;
}, },
controlNetControlModeChanged: (
state,
action: PayloadAction<{ controlNetId: string; controlMode: ControlModes }>
) => {
const { controlNetId, controlMode } = action.payload;
state.controlNets[controlNetId].controlMode = controlMode;
},
controlNetProcessorParamsChanged: ( controlNetProcessorParamsChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -307,6 +322,7 @@ export const {
controlNetWeightChanged, controlNetWeightChanged,
controlNetBeginStepPctChanged, controlNetBeginStepPctChanged,
controlNetEndStepPctChanged, controlNetEndStepPctChanged,
controlNetControlModeChanged,
controlNetProcessorParamsChanged, controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged, controlNetProcessorTypeChanged,
controlNetReset, controlNetReset,

View File

@ -0,0 +1,45 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useCallback } from 'react';
import { isEnabledToggled } from '../store/slice';
import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts';
import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial';
import { Flex } from '@chakra-ui/react';
const selector = createSelector(
stateSelector,
(state) => {
const { isEnabled } = state.dynamicPrompts;
return { isEnabled };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCollapse = () => {
const dispatch = useAppDispatch();
const { isEnabled } = useAppSelector(selector);
const handleToggleIsEnabled = useCallback(() => {
dispatch(isEnabledToggled());
}, [dispatch]);
return (
<IAICollapse
isOpen={isEnabled}
onToggle={handleToggleIsEnabled}
label="Dynamic Prompts"
withSwitch
>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamDynamicPromptsCombinatorial />
<ParamDynamicPromptsMaxPrompts />
</Flex>
</IAICollapse>
);
};
export default ParamDynamicPromptsCollapse;

View File

@ -0,0 +1,36 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { combinatorialToggled } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
import IAISwitch from 'common/components/IAISwitch';
const selector = createSelector(
stateSelector,
(state) => {
const { combinatorial } = state.dynamicPrompts;
return { combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsCombinatorial = () => {
const { combinatorial } = useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(() => {
dispatch(combinatorialToggled());
}, [dispatch]);
return (
<IAISwitch
label="Combinatorial Generation"
isChecked={combinatorial}
onChange={handleChange}
/>
);
};
export default ParamDynamicPromptsCombinatorial;

View File

@ -0,0 +1,55 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { maxPromptsChanged, maxPromptsReset } from '../store/slice';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useCallback } from 'react';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
stateSelector,
(state) => {
const { maxPrompts, combinatorial } = state.dynamicPrompts;
const { min, sliderMax, inputMax } =
state.config.sd.dynamicPrompts.maxPrompts;
return { maxPrompts, min, sliderMax, inputMax, combinatorial };
},
defaultSelectorOptions
);
const ParamDynamicPromptsMaxPrompts = () => {
const { maxPrompts, min, sliderMax, inputMax, combinatorial } =
useAppSelector(selector);
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => {
dispatch(maxPromptsChanged(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(maxPromptsReset());
}, [dispatch]);
return (
<IAISlider
label="Max Prompts"
isDisabled={!combinatorial}
min={min}
max={sliderMax}
value={maxPrompts}
onChange={handleChange}
sliderNumberInputProps={{ max: inputMax }}
withSliderMarks
withInput
inputReadOnly
withReset
handleReset={handleReset}
/>
);
};
export default ParamDynamicPromptsMaxPrompts;

View File

@ -0,0 +1,50 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
export interface DynamicPromptsState {
isEnabled: boolean;
maxPrompts: number;
combinatorial: boolean;
}
export const initialDynamicPromptsState: DynamicPromptsState = {
isEnabled: false,
maxPrompts: 100,
combinatorial: true,
};
const initialState: DynamicPromptsState = initialDynamicPromptsState;
export const dynamicPromptsSlice = createSlice({
name: 'dynamicPrompts',
initialState,
reducers: {
maxPromptsChanged: (state, action: PayloadAction<number>) => {
state.maxPrompts = action.payload;
},
maxPromptsReset: (state) => {
state.maxPrompts = initialDynamicPromptsState.maxPrompts;
},
combinatorialToggled: (state) => {
state.combinatorial = !state.combinatorial;
},
isEnabledToggled: (state) => {
state.isEnabled = !state.isEnabled;
},
},
extraReducers: (builder) => {
//
},
});
export const {
isEnabledToggled,
maxPromptsChanged,
maxPromptsReset,
combinatorialToggled,
} = dynamicPromptsSlice.actions;
export default dynamicPromptsSlice.reducer;
export const dynamicPromptsSelector = (state: RootState) =>
state.dynamicPrompts;

View File

@ -1,28 +1,41 @@
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
import { memo, useCallback } from 'react'; import { useCallback, forwardRef } from 'react';
import { import { Flex, Text } from '@chakra-ui/react';
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
} from '@chakra-ui/react';
import { FaEllipsisV } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { nodeAdded } from '../store/nodesSlice'; import { nodeAdded, nodesSelector } from '../store/nodesSlice';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { RootState } from 'app/store/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation'; import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { AnyInvocationType } from 'services/events/types'; import { AnyInvocationType } from 'services/events/types';
import IAIIconButton from 'common/components/IAIIconButton';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect';
type NodeTemplate = {
label: string;
value: string;
description: string;
};
const selector = createSelector(
nodesSelector,
(nodes) => {
const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => {
return {
label: template.title,
value: template.type,
description: template.description,
};
});
return { data };
},
defaultSelectorOptions
);
const AddNodeMenu = () => { const AddNodeMenu = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useAppSelector(selector);
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const buildInvocation = useBuildInvocation(); const buildInvocation = useBuildInvocation();
@ -46,23 +59,52 @@ const AddNodeMenu = () => {
); );
return ( return (
<Menu isLazy> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<MenuButton <IAIMantineMultiSelect
as={IAIIconButton} selectOnBlur={false}
aria-label="Add Node" placeholder="Add Node"
icon={<FaEllipsisV />} value={[]}
data={data}
maxDropdownHeight={400}
nothingFound="No matching nodes"
itemComponent={SelectItem}
filter={(value, selected, item: NodeTemplate) =>
item.label.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.description.toLowerCase().includes(value.toLowerCase().trim())
}
onChange={(v) => {
v[0] && addNode(v[0] as AnyInvocationType);
}}
sx={{
width: '18rem',
}}
/> />
<MenuList overflowY="scroll" height={400}> </Flex>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
); );
}; };
export default memo(AddNodeMenu); interface ItemProps extends React.ComponentPropsWithoutRef<'div'> {
value: string;
label: string;
description: string;
}
const SelectItem = forwardRef<HTMLDivElement, ItemProps>(
({ label, description, ...others }: ItemProps, ref) => {
return (
<div ref={ref} {...others}>
<div>
<Text>{label}</Text>
<Text size="xs" color="base.600">
{description}
</Text>
</div>
</div>
);
}
);
SelectItem.displayName = 'SelectItem';
export default AddNodeMenu;

View File

@ -23,7 +23,7 @@ const ModelInputFieldComponent = (
const { t } = useTranslation(); const { t } = useTranslation();
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {

View File

@ -1,10 +1,10 @@
import { memo } from 'react'; import { memo } from 'react';
import { Panel } from 'reactflow'; import { Panel } from 'reactflow';
import NodeSearch from '../search/NodeSearch'; import AddNodeMenu from '../AddNodeMenu';
const TopLeftPanel = () => ( const TopLeftPanel = () => (
<Panel position="top-left"> <Panel position="top-left">
<NodeSearch /> <AddNodeMenu />
</Panel> </Panel>
); );

View File

@ -14,9 +14,6 @@ import {
import { ImageField } from 'services/api/types'; import { ImageField } from 'services/api/types';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types'; import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
@ -78,25 +75,17 @@ const nodesSlice = createSlice({
shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => { shouldShowGraphOverlayChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowGraphOverlay = action.payload; state.shouldShowGraphOverlay = action.payload;
}, },
parsedOpenAPISchema: (state, action: PayloadAction<OpenAPIV3.Document>) => { nodeTemplatesBuilt: (
try { state,
const parsedSchema = parseSchema(action.payload); action: PayloadAction<Record<string, InvocationTemplate>>
) => {
// TODO: Achtung! Side effect in a reducer! state.invocationTemplates = action.payload;
log.info(
{ namespace: 'schema', nodes: parsedSchema },
`Parsed ${size(parsedSchema)} nodes`
);
state.invocationTemplates = parsedSchema;
} catch (err) {
console.error(err);
}
}, },
nodeEditorReset: () => { nodeEditorReset: () => {
return { ...initialNodesState }; return { ...initialNodesState };
}, },
}, },
extraReducers(builder) { extraReducers: (builder) => {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => { builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload; state.schema = action.payload;
}); });
@ -112,10 +101,10 @@ export const {
connectionStarted, connectionStarted,
connectionEnded, connectionEnded,
shouldShowGraphOverlayChanged, shouldShowGraphOverlayChanged,
parsedOpenAPISchema, nodeTemplatesBuilt,
nodeEditorReset, nodeEditorReset,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;
export const nodesSelecter = (state: RootState) => state.nodes; export const nodesSelector = (state: RootState) => state.nodes;

View File

@ -34,12 +34,10 @@ export type InvocationTemplate = {
* Array of invocation inputs * Array of invocation inputs
*/ */
inputs: Record<string, InputFieldTemplate>; inputs: Record<string, InputFieldTemplate>;
// inputs: InputField[];
/** /**
* Array of the invocation outputs * Array of the invocation outputs
*/ */
outputs: Record<string, OutputFieldTemplate>; outputs: Record<string, OutputFieldTemplate>;
// outputs: OutputField[];
}; };
export type FieldUIConfig = { export type FieldUIConfig = {
@ -335,7 +333,7 @@ export type TypeHints = {
}; };
export type InvocationSchemaExtra = { export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation output: OpenAPIV3.SchemaObject; // the output of the invocation
ui?: { ui?: {
tags?: string[]; tags?: string[];
type_hints?: TypeHints; type_hints?: TypeHints;

View File

@ -1,5 +1,5 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { filter, forEach, size } from 'lodash-es'; import { filter } from 'lodash-es';
import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types';
import { NonNullableGraph } from '../types/types'; import { NonNullableGraph } from '../types/types';
import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants';
@ -19,9 +19,9 @@ export const addControlNetToLinearGraph = (
(c.processorType === 'none' && Boolean(c.controlImage))) (c.processorType === 'none' && Boolean(c.controlImage)))
); );
// Add ControlNet if (isControlNetEnabled && Boolean(validControlNets.length)) {
if (isControlNetEnabled && validControlNets.length > 0) { if (validControlNets.length > 1) {
if (size(controlNets) > 1) { // We have multiple controlnets, add ControlNet collector
const controlNetIterateNode: CollectInvocation = { const controlNetIterateNode: CollectInvocation = {
id: CONTROL_NET_COLLECT, id: CONTROL_NET_COLLECT,
type: 'collect', type: 'collect',
@ -36,29 +36,25 @@ export const addControlNetToLinearGraph = (
}); });
} }
forEach(controlNets, (controlNet) => { validControlNets.forEach((controlNet) => {
const { const {
controlNetId, controlNetId,
isEnabled,
controlImage, controlImage,
processedControlImage, processedControlImage,
beginStepPct, beginStepPct,
endStepPct, endStepPct,
controlMode,
model, model,
processorType, processorType,
weight, weight,
} = controlNet; } = controlNet;
if (!isEnabled) {
// Skip disabled ControlNets
return;
}
const controlNetNode: ControlNetInvocation = { const controlNetNode: ControlNetInvocation = {
id: `control_net_${controlNetId}`, id: `control_net_${controlNetId}`,
type: 'controlnet', type: 'controlnet',
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
control_mode: controlMode,
control_model: model as ControlNetInvocation['control_model'], control_model: model as ControlNetInvocation['control_model'],
control_weight: weight, control_weight: weight,
}; };
@ -80,7 +76,8 @@ export const addControlNetToLinearGraph = (
graph.nodes[controlNetNode.id] = controlNetNode; graph.nodes[controlNetNode.id] = controlNetNode;
if (size(controlNets) > 1) { if (validControlNets.length > 1) {
// if we have multiple controlnets, link to the collector
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {
@ -89,6 +86,7 @@ export const addControlNetToLinearGraph = (
}, },
}); });
} else { } else {
// otherwise, link directly to the base node
graph.edges.push({ graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' }, source: { node_id: controlNetNode.id, field: 'control' },
destination: { destination: {

View File

@ -349,21 +349,11 @@ export const getFieldType = (
if (typeHints && name in typeHints) { if (typeHints && name in typeHints) {
rawFieldType = typeHints[name]; rawFieldType = typeHints[name];
} else if (!schemaObject.type) { } else if (!schemaObject.type && schemaObject.allOf) {
// if schemaObject has no type, then it should have one of allOf, anyOf, oneOf // if schemaObject has no type, then it should have one of allOf
if (schemaObject.allOf) { rawFieldType =
rawFieldType = refObjectToFieldType( (schemaObject.allOf[0] as OpenAPIV3.SchemaObject).title ??
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject 'Missing Field Type';
);
} else if (schemaObject.anyOf) {
rawFieldType = refObjectToFieldType(
schemaObject.anyOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.oneOf) {
rawFieldType = refObjectToFieldType(
schemaObject.oneOf![0] as OpenAPIV3.ReferenceObject
);
}
} else if (schemaObject.enum) { } else if (schemaObject.enum) {
rawFieldType = 'enum'; rawFieldType = 'enum';
} else if (schemaObject.type) { } else if (schemaObject.type) {

View File

@ -0,0 +1,153 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import {
DynamicPromptInvocation,
IterateInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
DYNAMIC_PROMPT,
ITERATE,
NOISE,
POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
import { unset } from 'lodash-es';
export const addDynamicPromptsToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { positivePrompt, iterations, seed, shouldRandomizeSeed } =
state.generation;
const {
combinatorial,
isEnabled: isDynamicPromptsEnabled,
maxPrompts,
} = state.dynamicPrompts;
if (isDynamicPromptsEnabled) {
// iteration is handled via dynamic prompts
unset(graph.nodes[POSITIVE_CONDITIONING], 'prompt');
const dynamicPromptNode: DynamicPromptInvocation = {
id: DYNAMIC_PROMPT,
type: 'dynamic_prompt',
max_prompts: combinatorial ? maxPrompts : iterations,
combinatorial,
prompt: positivePrompt,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[DYNAMIC_PROMPT] = dynamicPromptNode;
graph.nodes[ITERATE] = iterateNode;
// connect dynamic prompts to compel nodes
graph.edges.push(
{
source: {
node_id: DYNAMIC_PROMPT,
field: 'prompt_collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'prompt',
},
}
);
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: NOISE, field: 'seed' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[NOISE] as NoiseInvocation).seed = seed;
}
} else {
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
step: 1,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
graph.nodes[ITERATE] = iterateNode;
graph.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graph.edges.push({
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graph.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
rangeOfSizeNode.start = seed;
}
}
};

View File

@ -2,6 +2,7 @@ import { RootState } from 'app/store/store';
import { import {
ImageDTO, ImageDTO,
ImageResizeInvocation, ImageResizeInvocation,
ImageToLatentsInvocation,
RandomIntInvocation, RandomIntInvocation,
RangeOfSizeInvocation, RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
@ -10,7 +11,7 @@ import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -24,6 +25,7 @@ import {
import { set } from 'lodash-es'; import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -75,31 +77,19 @@ export const buildCanvasImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -120,7 +110,7 @@ export const buildCanvasImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -130,7 +120,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -140,7 +130,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -148,26 +138,6 @@ export const buildCanvasImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -200,7 +170,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -210,7 +180,7 @@ export const buildCanvasImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -241,26 +211,6 @@ export const buildCanvasImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if (initialImage.width !== width || initialImage.height !== height) { if (initialImage.width !== width || initialImage.height !== height) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
@ -306,9 +256,9 @@ export const buildCanvasImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.image_name, image_name: initialImage.image_name,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -327,7 +277,10 @@ export const buildCanvasImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;

View File

@ -9,7 +9,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE, ITERATE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT, RANDOM_INT,
@ -101,9 +101,9 @@ export const buildCanvasInpaintGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[RANGE_OF_SIZE]: { [RANGE_OF_SIZE]: {
@ -142,7 +142,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -152,7 +152,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -162,7 +162,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -172,7 +172,7 @@ export const buildCanvasInpaintGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {

View File

@ -4,7 +4,7 @@ import { RandomIntInvocation, RangeOfSizeInvocation } from 'services/api/types';
import { import {
ITERATE, ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
@ -15,6 +15,7 @@ import {
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
/** /**
* Builds the Canvas tab's Text to Image graph. * Builds the Canvas tab's Text to Image graph.
@ -62,13 +63,6 @@ export const buildCanvasTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
@ -82,19 +76,15 @@ export const buildCanvasTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -119,7 +109,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -129,7 +119,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -139,7 +129,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -159,7 +149,7 @@ export const buildCanvasTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -167,26 +157,6 @@ export const buildCanvasTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -200,27 +170,10 @@ export const buildCanvasTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;

View File

@ -1,28 +1,24 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { import {
ImageResizeInvocation, ImageResizeInvocation,
RandomIntInvocation, ImageToLatentsInvocation,
RangeOfSizeInvocation,
} from 'services/api/types'; } from 'services/api/types';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
IMAGE_TO_IMAGE_GRAPH, IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS, IMAGE_TO_LATENTS,
LATENTS_TO_LATENTS, LATENTS_TO_LATENTS,
RESIZE, RESIZE,
} from './constants'; } from './constants';
import { set } from 'lodash-es';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
const moduleLog = log.child({ namespace: 'nodes' }); const moduleLog = log.child({ namespace: 'nodes' });
@ -44,9 +40,6 @@ export const buildLinearImageToImageGraph = (
shouldFitToWidthHeight, shouldFitToWidthHeight,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
/** /**
@ -79,31 +72,19 @@ export const buildLinearImageToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// seed - must be connected manually
// start: 0,
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
id: LATENTS_TO_LATENTS, id: LATENTS_TO_LATENTS,
@ -124,7 +105,7 @@ export const buildLinearImageToImageGraph = (
edges: [ edges: [
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -134,7 +115,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -144,7 +125,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -152,26 +133,6 @@ export const buildLinearImageToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: LATENTS_TO_LATENTS, node_id: LATENTS_TO_LATENTS,
@ -204,7 +165,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -214,7 +175,7 @@ export const buildLinearImageToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -245,26 +206,6 @@ export const buildLinearImageToImageGraph = (
], ],
}; };
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode;
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// handle `fit` // handle `fit`
if ( if (
shouldFitToWidthHeight && shouldFitToWidthHeight &&
@ -313,9 +254,9 @@ export const buildLinearImageToImageGraph = (
}); });
} else { } else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly // We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
set(graph.nodes[IMAGE_TO_LATENTS], 'image', { (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName, image_name: initialImage.imageName,
}); };
// Pass the image's dimensions to the `NOISE` node // Pass the image's dimensions to the `NOISE` node
graph.edges.push({ graph.edges.push({
@ -334,7 +275,10 @@ export const buildLinearImageToImageGraph = (
}); });
} }
// add controlnet // add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state); addControlNetToLinearGraph(graph, LATENTS_TO_LATENTS, state);
return graph; return graph;

View File

@ -1,33 +1,20 @@
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types'; import { NonNullableGraph } from 'features/nodes/types/types';
import { import {
BaseModelType,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api/types';
import {
ITERATE,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
MODEL_LOADER, PIPELINE_MODEL_LOADER,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NOISE, NOISE,
POSITIVE_CONDITIONING, POSITIVE_CONDITIONING,
RANDOM_INT,
RANGE_OF_SIZE,
TEXT_TO_IMAGE_GRAPH, TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS, TEXT_TO_LATENTS,
} from './constants'; } from './constants';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField'; import { modelIdToPipelineModelField } from '../modelIdToPipelineModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
type TextToImageGraphOverrides = {
width: number;
height: number;
};
export const buildLinearTextToImageGraph = ( export const buildLinearTextToImageGraph = (
state: RootState, state: RootState
overrides?: TextToImageGraphOverrides
): NonNullableGraph => { ): NonNullableGraph => {
const { const {
positivePrompt, positivePrompt,
@ -38,9 +25,6 @@ export const buildLinearTextToImageGraph = (
steps, steps,
width, width,
height, height,
iterations,
seed,
shouldRandomizeSeed,
} = state.generation; } = state.generation;
const model = modelIdToPipelineModelField(modelId); const model = modelIdToPipelineModelField(modelId);
@ -68,18 +52,11 @@ export const buildLinearTextToImageGraph = (
id: NEGATIVE_CONDITIONING, id: NEGATIVE_CONDITIONING,
prompt: negativePrompt, prompt: negativePrompt,
}, },
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
// start: 0, // seed - must be connected manually
size: iterations,
step: 1,
},
[NOISE]: { [NOISE]: {
type: 'noise', type: 'noise',
id: NOISE, id: NOISE,
width: overrides?.width || width, width,
height: overrides?.height || height, height,
}, },
[TEXT_TO_LATENTS]: { [TEXT_TO_LATENTS]: {
type: 't2l', type: 't2l',
@ -88,19 +65,15 @@ export const buildLinearTextToImageGraph = (
scheduler, scheduler,
steps, steps,
}, },
[MODEL_LOADER]: { [PIPELINE_MODEL_LOADER]: {
type: 'pipeline_model_loader', type: 'pipeline_model_loader',
id: MODEL_LOADER, id: PIPELINE_MODEL_LOADER,
model, model,
}, },
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
}, },
[ITERATE]: {
type: 'iterate',
id: ITERATE,
},
}, },
edges: [ edges: [
{ {
@ -125,7 +98,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -135,7 +108,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'clip', field: 'clip',
}, },
destination: { destination: {
@ -145,7 +118,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'unet', field: 'unet',
}, },
destination: { destination: {
@ -165,7 +138,7 @@ export const buildLinearTextToImageGraph = (
}, },
{ {
source: { source: {
node_id: MODEL_LOADER, node_id: PIPELINE_MODEL_LOADER,
field: 'vae', field: 'vae',
}, },
destination: { destination: {
@ -173,26 +146,6 @@ export const buildLinearTextToImageGraph = (
field: 'vae', field: 'vae',
}, },
}, },
{
source: {
node_id: RANGE_OF_SIZE,
field: 'collection',
},
destination: {
node_id: ITERATE,
field: 'collection',
},
},
{
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
},
{ {
source: { source: {
node_id: NOISE, node_id: NOISE,
@ -206,27 +159,10 @@ export const buildLinearTextToImageGraph = (
], ],
}; };
// handle seed // add dynamic prompts, mutating `graph`
if (shouldRandomizeSeed) { addDynamicPromptsToGraph(graph, state);
// Random int node to generate the starting seed
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
graph.nodes[RANDOM_INT] = randomIntNode; // add controlnet, mutating `graph`
// Connect random int to the start of the range of size so the range starts on the random first seed
graph.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
} else {
// User specified seed, so set the start of the range of size to the seed
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// add controlnet
addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state); addControlNetToLinearGraph(graph, TEXT_TO_LATENTS, state);
return graph; return graph;

View File

@ -7,12 +7,13 @@ export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int'; export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size'; export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate'; export const ITERATE = 'iterate';
export const MODEL_LOADER = 'pipeline_model_loader'; export const PIPELINE_MODEL_LOADER = 'pipeline_model_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents'; export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image'; export const RESIZE = 'resize_image';
export const INPAINT = 'inpaint'; export const INPAINT = 'inpaint';
export const CONTROL_NET_COLLECT = 'control_net_collect'; export const CONTROL_NET_COLLECT = 'control_net_collect';
export const DYNAMIC_PROMPT = 'dynamic_prompt';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { CompelInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildCompelNode = (
prompt: string,
state: RootState,
overrides: O.Partial<CompelInvocation, 'deep'> = {}
): CompelInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const { model } = generation;
const compelNode: CompelInvocation = {
id: nodeId,
type: 'compel',
prompt,
model,
};
Object.assign(compelNode, overrides);
return compelNode;
};

View File

@ -1,107 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api/types';
import { O } from 'ts-toolbelt';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
export const buildImg2ImgNode = (
state: RootState,
overrides: O.Partial<ImageToImageInvocation, 'deep'> = {}
): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const activeTabName = activeTabNameSelector(state);
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
initialImage,
} = generation;
// const initialImage = initialImageSelector(state);
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
// on Canvas tab, we do not manually specific init image
if (activeTabName !== 'unifiedCanvas') {
if (!initialImage) {
// TODO: handle this more better
throw 'no initial image';
}
imageToImageNode.image = {
image_name: initialImage.imageName,
};
}
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
Object.assign(imageToImageNode, overrides);
return imageToImageNode;
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
fit: true,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@ -1,48 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { InpaintInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildInpaintNode = (
state: RootState,
overrides: O.Partial<InpaintInvocation, 'deep'> = {}
): InpaintInvocation => {
const nodeId = uuidv4();
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale,
scheduler,
model,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = state.generation;
const inpaintNode: InpaintInvocation = {
id: nodeId,
type: 'inpaint',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler,
model,
strength,
fit,
};
if (!shouldRandomizeSeed) {
inpaintNode.seed = seed;
}
Object.assign(inpaintNode, overrides);
return inpaintNode;
};

View File

@ -1,13 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { IterateInvocation } from 'services/api/types';
export const buildIterateNode = (): IterateInvocation => {
const nodeId = uuidv4();
return {
id: nodeId,
type: 'iterate',
// collection: [],
// index: 0,
};
};

View File

@ -1,26 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { RandomRangeInvocation, RangeInvocation } from 'services/api/types';
export const buildRangeNode = (
state: RootState
): RangeInvocation | RandomRangeInvocation => {
const nodeId = uuidv4();
const { shouldRandomizeSeed, iterations, seed } = state.generation;
if (shouldRandomizeSeed) {
return {
id: nodeId,
type: 'random_range',
size: iterations,
};
}
return {
id: nodeId,
type: 'range',
start: seed,
stop: seed + iterations,
};
};

View File

@ -1,45 +0,0 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { TextToImageInvocation } from 'services/api/types';
import { O } from 'ts-toolbelt';
export const buildTxt2ImgNode = (
state: RootState,
overrides: O.Partial<TextToImageInvocation, 'deep'> = {}
): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const {
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
scheduler,
shouldRandomizeSeed,
model,
} = generation;
const textToImageNode: NonNullable<TextToImageInvocation> = {
id: nodeId,
type: 'txt2img',
prompt: `${prompt} [${negativePrompt}]`,
steps,
width,
height,
cfg_scale,
scheduler,
model,
};
if (!shouldRandomizeSeed) {
textToImageNode.seed = seed;
}
Object.assign(textToImageNode, overrides);
return textToImageNode;
};

View File

@ -5,127 +5,154 @@ import {
InputFieldTemplate, InputFieldTemplate,
InvocationSchemaObject, InvocationSchemaObject,
InvocationTemplate, InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate, OutputFieldTemplate,
} from '../types/types'; } from '../types/types';
import { import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
buildInputFieldTemplate, import { O } from 'ts-toolbelt';
buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; // recursively exclude all properties of type U from T
type DeepExclude<T, U> = T extends U
? never
: T extends object
? {
[K in keyof T]: DeepExclude<T[K], U>;
}
: T;
// The schema from swagger-parser is dereferenced, and we know `components` and `components.schemas` exist
type DereferencedOpenAPIDocument = DeepExclude<
O.Required<OpenAPIV3.Document, 'schemas' | 'components', 'deep'>,
OpenAPIV3.ReferenceObject
>;
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate']; const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate'];
const invocationDenylist = ['Graph', 'InvocationMeta']; const invocationDenylist = ['Graph', 'InvocationMeta'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { const nodeFilter = (
schema: DereferencedOpenAPIDocument['components']['schemas'][string],
key: string
) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem));
export const parseSchema = (openAPI: DereferencedOpenAPIDocument) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter( const filteredSchemas = filter(openAPI.components.schemas, nodeFilter);
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationDenylist.some((denylistItem) => key.includes(denylistItem))
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce< const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate> Record<string, InvocationTemplate>
>((acc, schema) => { >((acc, s) => {
// only want SchemaObjects // cast to InvocationSchemaObject, we know the shape
if (isInvocationSchemaObject(schema)) { const schema = s as InvocationSchemaObject;
const type = schema.properties.type.default;
const title = schema.ui?.title ?? schema.title.replace('Invocation', ''); const type = schema.properties.type.default;
const typeHints = schema.ui?.type_hints; const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const inputs: Record<string, InputFieldTemplate> = {}; const typeHints = schema.ui?.type_hints;
if (type === 'collect') { const inputs: Record<string, InputFieldTemplate> = {};
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
inputs.collection = { if (type === 'collect') {
type: 'array', // Special handling for the Collect node
name: 'collection', const itemProperty = schema.properties['item'] as InvocationSchemaObject;
title: itemProperty.title ?? '', inputs.item = {
default: [], type: 'item',
description: itemProperty.description ?? '', name: 'item',
inputRequirement: 'always', description: itemProperty.description ?? '',
inputKind: 'connection', title: 'Collection Item',
}; inputKind: 'connection',
} else { inputRequirement: 'always',
// All other nodes default: undefined,
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
}
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
const iterationOutput = openAPI.components!.schemas![
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
}; };
} else if (type === 'iterate') {
// Special handling for the Iterate node
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
Object.assign(acc, { [type]: invocation }); inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
inputs
);
} }
let outputs: Record<string, OutputFieldTemplate>;
if (type === 'iterate') {
// Special handling for the Iterate node output
const iterationOutput =
openAPI.components.schemas['IterateInvocationOutput'];
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
// All other node outputs
outputs = reduce(
schema.output.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (!['type', 'id'].includes(propertyName)) {
const fieldType = getFieldType(property, propertyName, typeHints);
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
Object.assign(acc, { [type]: invocation });
return acc; return acc;
}, {}); }, {});

View File

@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput'; import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
@ -10,27 +11,27 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createSelector( const selector = createSelector([stateSelector], (state) => {
[generationSelector, configSelector, uiSelector, hotkeysSelector], const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
(generation, config, ui, hotkeys) => { state.config.sd.iterations;
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = const { iterations } = state.generation;
config.sd.iterations; const { shouldUseSliders } = state.ui;
const { iterations } = generation; const isDisabled =
const { shouldUseSliders } = ui; state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial;
const step = hotkeys.shift ? fineStep : coarseStep; const step = state.hotkeys.shift ? fineStep : coarseStep;
return { return {
iterations, iterations,
initial, initial,
min, min,
sliderMax, sliderMax,
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
}; isDisabled,
} };
); });
const ParamIterations = () => { const ParamIterations = () => {
const { const {
@ -41,6 +42,7 @@ const ParamIterations = () => {
inputMax, inputMax,
step, step,
shouldUseSliders, shouldUseSliders,
isDisabled,
} = useAppSelector(selector); } = useAppSelector(selector);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -58,6 +60,7 @@ const ParamIterations = () => {
return shouldUseSliders ? ( return shouldUseSliders ? (
<IAISlider <IAISlider
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}
@ -72,6 +75,7 @@ const ParamIterations = () => {
/> />
) : ( ) : (
<IAINumberInput <IAINumberInput
isDisabled={isDisabled}
label={t('parameters.images')} label={t('parameters.images')}
step={step} step={step}
min={min} min={min}

View File

@ -24,7 +24,7 @@ const ModelSelect = () => {
); );
const { data: pipelineModels } = useListModelsQuery({ const { data: pipelineModels } = useListModelsQuery({
model_type: 'pipeline', model_type: 'main',
}); });
const data = useMemo(() => { const data = useMemo(() => {

View File

@ -60,6 +60,14 @@ export const initialConfigState: AppConfig = {
fineStep: 0.01, fineStep: 0.01,
coarseStep: 0.05, coarseStep: 0.05,
}, },
dynamicPrompts: {
maxPrompts: {
initial: 100,
min: 1,
sliderMax: 1000,
inputMax: 10000,
},
},
}, },
}; };

View File

@ -4,7 +4,6 @@ import * as InvokeAI from 'app/types/invokeai';
import { InvokeLogLevel } from 'app/logging/useLogger'; import { InvokeLogLevel } from 'app/logging/useLogger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { TFuncKey, t } from 'i18next'; import { TFuncKey, t } from 'i18next';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { import {
@ -26,6 +25,7 @@ import {
} from 'services/api/thunks/session'; } from 'services/api/thunks/session';
import { makeToast } from '../../../app/components/Toaster'; import { makeToast } from '../../../app/components/Toaster';
import { LANGUAGES } from '../components/LanguagePicker'; import { LANGUAGES } from '../components/LanguagePicker';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
export type CancelStrategy = 'immediate' | 'scheduled'; export type CancelStrategy = 'immediate' | 'scheduled';
@ -382,7 +382,7 @@ export const systemSlice = createSlice({
/** /**
* OpenAPI schema was parsed * OpenAPI schema was parsed
*/ */
builder.addCase(parsedOpenAPISchema, (state) => { builder.addCase(nodeTemplatesBuilt, (state) => {
state.wasSchemaParsed = true; state.wasSchemaParsed = true;
}); });

View File

@ -17,7 +17,6 @@ import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react'; import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { MdDeviceHub, MdGridOn } from 'react-icons/md'; import { MdDeviceHub, MdGridOn } from 'react-icons/md';
import { GoTextSize } from 'react-icons/go';
import { import {
activeTabIndexSelector, activeTabIndexSelector,
activeTabNameSelector, activeTabNameSelector,
@ -33,7 +32,7 @@ import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent
import TextToImageTab from './tabs/TextToImage/TextToImageTab'; import TextToImageTab from './tabs/TextToImage/TextToImageTab';
import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab'; import UnifiedCanvasTab from './tabs/UnifiedCanvas/UnifiedCanvasTab';
import NodesTab from './tabs/Nodes/NodesTab'; import NodesTab from './tabs/Nodes/NodesTab';
import { FaImage } from 'react-icons/fa'; import { FaFont, FaImage } from 'react-icons/fa';
import ResizeHandle from './tabs/ResizeHandle'; import ResizeHandle from './tabs/ResizeHandle';
import ImageTab from './tabs/ImageToImage/ImageToImageTab'; import ImageTab from './tabs/ImageToImage/ImageToImageTab';
import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator'; import AuxiliaryProgressIndicator from 'app/components/AuxiliaryProgressIndicator';
@ -47,7 +46,7 @@ export interface InvokeTabInfo {
const tabs: InvokeTabInfo[] = [ const tabs: InvokeTabInfo[] = [
{ {
id: 'txt2img', id: 'txt2img',
icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />, icon: <Icon as={FaFont} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <TextToImageTab />, content: <TextToImageTab />,
}, },
{ {

View File

@ -8,6 +8,7 @@ import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Sym
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const ImageToImageTabParameters = () => { const ImageToImageTabParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const ImageToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<ImageToImageTabCoreParameters /> <ImageToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />

View File

@ -9,6 +9,7 @@ import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const TextToImageTabParameters = () => { const TextToImageTabParameters = () => {
return ( return (
@ -17,6 +18,7 @@ const TextToImageTabParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<TextToImageTabCoreParameters /> <TextToImageTabCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamNoiseCollapse /> <ParamNoiseCollapse />

View File

@ -8,6 +8,7 @@ import { memo } from 'react';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
const UnifiedCanvasParameters = () => { const UnifiedCanvasParameters = () => {
return ( return (
@ -16,6 +17,7 @@ const UnifiedCanvasParameters = () => {
<ParamNegativeConditioning /> <ParamNegativeConditioning />
<ProcessButtons /> <ProcessButtons />
<UnifiedCanvasCoreParameters /> <UnifiedCanvasCoreParameters />
<ParamDynamicPromptsCollapse />
<ParamControlNetCollapse /> <ParamControlNetCollapse />
<ParamVariationCollapse /> <ParamVariationCollapse />
<ParamSymmetryCollapse /> <ParamSymmetryCollapse />

View File

@ -15,6 +15,7 @@ export const imagesApi = api.injectEndpoints({
} }
return tags; return tags;
}, },
keepUnusedDataFor: 86400, // 24 hours
}), }),
}), }),
}); });

View File

@ -0,0 +1,36 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Applies HED edge detection to image
*/
export type HedImageProcessorInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Whether or not this node is an intermediate node.
*/
is_intermediate?: boolean;
type?: 'hed_image_processor';
/**
* The image to process
*/
image?: ImageField;
/**
* The pixel resolution for detection
*/
detect_resolution?: number;
/**
* The pixel resolution for the output image
*/
image_resolution?: number;
/**
* Whether to use scribble mode
*/
scribble?: boolean;
};

View File

@ -648,6 +648,13 @@ export type components = {
* @default 1 * @default 1
*/ */
end_step_percent: number; end_step_percent: number;
/**
* Control Mode
* @description The contorl mode to use
* @default balanced
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
}; };
/** /**
* ControlNetInvocation * ControlNetInvocation
@ -701,6 +708,13 @@ export type components = {
* @default 1 * @default 1
*/ */
end_step_percent?: number; end_step_percent?: number;
/**
* Control Mode
* @description The control mode used
* @default balanced
* @enum {string}
*/
control_mode?: "balanced" | "more_prompt" | "more_control" | "unbalanced";
}; };
/** ControlNetModelConfig */ /** ControlNetModelConfig */
ControlNetModelConfig: { ControlNetModelConfig: {
@ -1016,7 +1030,7 @@ export type components = {
* @description The nodes in this graph * @description The nodes in this graph
*/ */
nodes?: { nodes?: {
[key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined; [key: string]: (components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]) | undefined;
}; };
/** /**
* Edges * Edges
@ -1059,7 +1073,7 @@ export type components = {
* @description The results of node executions * @description The results of node executions
*/ */
results: { results: {
[key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined; [key: string]: (components["schemas"]["ImageOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["PromptOutput"] | components["schemas"]["PromptCollectionOutput"] | components["schemas"]["CompelOutput"] | components["schemas"]["IntOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["IntCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["CollectInvocationOutput"]) | undefined;
}; };
/** /**
* Errors * Errors
@ -2903,7 +2917,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -2979,6 +2993,18 @@ export type components = {
* @default 512 * @default 512
*/ */
height?: number; height?: number;
/**
* Perlin
* @description The amount of perlin noise to add to the noise
* @default 0
*/
perlin?: number;
/**
* Use Cpu
* @description Use CPU for noise generation (for reproducible results across platforms)
* @default true
*/
use_cpu?: boolean;
}; };
/** /**
* NoiseOutput * NoiseOutput
@ -4163,18 +4189,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -4285,7 +4311,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {
@ -4322,7 +4348,7 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"]; "application/json": components["schemas"]["LoadImageInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageProcessorInvocation"] | components["schemas"]["PipelineModelLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ParamIntInvocation"] | components["schemas"]["ParamFloatInvocation"] | components["schemas"]["TextToLatentsInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["UpscaleInvocation"] | components["schemas"]["RestoreFaceInvocation"] | components["schemas"]["InpaintInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["LatentsToLatentsInvocation"];
}; };
}; };
responses: { responses: {

View File

@ -1,20 +1,45 @@
import SwaggerParser from '@apidevtools/swagger-parser';
import { createAsyncThunk } from '@reduxjs/toolkit'; import { createAsyncThunk } from '@reduxjs/toolkit';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { parsedOpenAPISchema } from 'features/nodes/store/nodesSlice';
import { OpenAPIV3 } from 'openapi-types'; import { OpenAPIV3 } from 'openapi-types';
const schemaLog = log.child({ namespace: 'schema' }); const schemaLog = log.child({ namespace: 'schema' });
function getCircularReplacer() {
const ancestors: Record<string, any>[] = [];
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}
if (ancestors.includes(value)) {
return '[Circular]';
}
ancestors.push(value);
return value;
};
}
export const receivedOpenAPISchema = createAsyncThunk( export const receivedOpenAPISchema = createAsyncThunk(
'nodes/receivedOpenAPISchema', 'nodes/receivedOpenAPISchema',
async (_, { dispatch }): Promise<OpenAPIV3.Document> => { async (_, { dispatch, rejectWithValue }) => {
const response = await fetch(`openapi.json`); try {
const openAPISchema = await response.json(); const dereferencedSchema = (await SwaggerParser.dereference(
'openapi.json'
)) as OpenAPIV3.Document;
schemaLog.info({ openAPISchema }, 'Received OpenAPI schema'); const schemaJSON = JSON.parse(
JSON.stringify(dereferencedSchema, getCircularReplacer())
);
dispatch(parsedOpenAPISchema(openAPISchema as OpenAPIV3.Document)); return schemaJSON;
} catch (error) {
return openAPISchema; return rejectWithValue({ error });
}
} }
); );

View File

@ -1,81 +1,92 @@
import { O } from 'ts-toolbelt';
import { components } from './schema'; import { components } from './schema';
type schemas = components['schemas'];
/** /**
* Types from the API, re-exported from the types generated by `openapi-typescript`. * Extracts the schema type from the schema.
*/ */
type S<T extends keyof components['schemas']> = components['schemas'][T];
/**
* Extracts the node type from the schema.
* Also flags the `type` property as required.
*/
type N<T extends keyof components['schemas']> = O.Required<
components['schemas'][T],
'type'
>;
// Images // Images
export type ImageDTO = components['schemas']['ImageDTO']; export type ImageDTO = S<'ImageDTO'>;
export type BoardDTO = components['schemas']['BoardDTO']; export type BoardDTO = S<'BoardDTO'>;
export type BoardChanges = components['schemas']['BoardChanges']; export type BoardChanges = S<'BoardChanges'>;
export type ImageChanges = components['schemas']['ImageRecordChanges']; export type ImageChanges = S<'ImageRecordChanges'>;
export type ImageCategory = components['schemas']['ImageCategory']; export type ImageCategory = S<'ImageCategory'>;
export type ResourceOrigin = components['schemas']['ResourceOrigin']; export type ResourceOrigin = S<'ResourceOrigin'>;
export type ImageField = components['schemas']['ImageField']; export type ImageField = S<'ImageField'>;
export type OffsetPaginatedResults_BoardDTO_ = export type OffsetPaginatedResults_BoardDTO_ =
components['schemas']['OffsetPaginatedResults_BoardDTO_']; S<'OffsetPaginatedResults_BoardDTO_'>;
export type OffsetPaginatedResults_ImageDTO_ = export type OffsetPaginatedResults_ImageDTO_ =
components['schemas']['OffsetPaginatedResults_ImageDTO_']; S<'OffsetPaginatedResults_ImageDTO_'>;
// Models // Models
export type ModelType = components['schemas']['ModelType']; export type ModelType = S<'ModelType'>;
export type BaseModelType = components['schemas']['BaseModelType']; export type BaseModelType = S<'BaseModelType'>;
export type PipelineModelField = components['schemas']['PipelineModelField']; export type PipelineModelField = S<'PipelineModelField'>;
export type ModelsList = components['schemas']['ModelsList']; export type ModelsList = S<'ModelsList'>;
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = S<'Graph'>;
export type Edge = components['schemas']['Edge']; export type Edge = S<'Edge'>;
export type GraphExecutionState = components['schemas']['GraphExecutionState']; export type GraphExecutionState = S<'GraphExecutionState'>;
// General nodes // General nodes
export type CollectInvocation = components['schemas']['CollectInvocation']; export type CollectInvocation = N<'CollectInvocation'>;
export type IterateInvocation = components['schemas']['IterateInvocation']; export type IterateInvocation = N<'IterateInvocation'>;
export type RangeInvocation = components['schemas']['RangeInvocation']; export type RangeInvocation = N<'RangeInvocation'>;
export type RandomRangeInvocation = export type RandomRangeInvocation = N<'RandomRangeInvocation'>;
components['schemas']['RandomRangeInvocation']; export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>;
export type RangeOfSizeInvocation = export type InpaintInvocation = N<'InpaintInvocation'>;
components['schemas']['RangeOfSizeInvocation']; export type ImageResizeInvocation = N<'ImageResizeInvocation'>;
export type InpaintInvocation = components['schemas']['InpaintInvocation']; export type RandomIntInvocation = N<'RandomIntInvocation'>;
export type ImageResizeInvocation = export type CompelInvocation = N<'CompelInvocation'>;
components['schemas']['ImageResizeInvocation']; export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>;
export type RandomIntInvocation = components['schemas']['RandomIntInvocation']; export type NoiseInvocation = N<'NoiseInvocation'>;
export type CompelInvocation = components['schemas']['CompelInvocation']; export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>;
export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>;
export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>;
export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>;
export type PipelineModelLoaderInvocation = N<'PipelineModelLoaderInvocation'>;
// ControlNet Nodes // ControlNet Nodes
export type CannyImageProcessorInvocation = export type ControlNetInvocation = N<'ControlNetInvocation'>;
components['schemas']['CannyImageProcessorInvocation']; export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>;
export type ContentShuffleImageProcessorInvocation = export type ContentShuffleImageProcessorInvocation =
components['schemas']['ContentShuffleImageProcessorInvocation']; N<'ContentShuffleImageProcessorInvocation'>;
export type HedImageProcessorInvocation = export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>;
components['schemas']['HedImageProcessorInvocation'];
export type LineartAnimeImageProcessorInvocation = export type LineartAnimeImageProcessorInvocation =
components['schemas']['LineartAnimeImageProcessorInvocation']; N<'LineartAnimeImageProcessorInvocation'>;
export type LineartImageProcessorInvocation = export type LineartImageProcessorInvocation =
components['schemas']['LineartImageProcessorInvocation']; N<'LineartImageProcessorInvocation'>;
export type MediapipeFaceProcessorInvocation = export type MediapipeFaceProcessorInvocation =
components['schemas']['MediapipeFaceProcessorInvocation']; N<'MediapipeFaceProcessorInvocation'>;
export type MidasDepthImageProcessorInvocation = export type MidasDepthImageProcessorInvocation =
components['schemas']['MidasDepthImageProcessorInvocation']; N<'MidasDepthImageProcessorInvocation'>;
export type MlsdImageProcessorInvocation = export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>;
components['schemas']['MlsdImageProcessorInvocation'];
export type NormalbaeImageProcessorInvocation = export type NormalbaeImageProcessorInvocation =
components['schemas']['NormalbaeImageProcessorInvocation']; N<'NormalbaeImageProcessorInvocation'>;
export type OpenposeImageProcessorInvocation = export type OpenposeImageProcessorInvocation =
components['schemas']['OpenposeImageProcessorInvocation']; N<'OpenposeImageProcessorInvocation'>;
export type PidiImageProcessorInvocation = export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>;
components['schemas']['PidiImageProcessorInvocation'];
export type ZoeDepthImageProcessorInvocation = export type ZoeDepthImageProcessorInvocation =
components['schemas']['ZoeDepthImageProcessorInvocation']; N<'ZoeDepthImageProcessorInvocation'>;
// Node Outputs // Node Outputs
export type ImageOutput = components['schemas']['ImageOutput']; export type ImageOutput = S<'ImageOutput'>;
export type MaskOutput = components['schemas']['MaskOutput']; export type MaskOutput = S<'MaskOutput'>;
export type PromptOutput = components['schemas']['PromptOutput']; export type PromptOutput = S<'PromptOutput'>;
export type IterateInvocationOutput = export type IterateInvocationOutput = S<'IterateInvocationOutput'>;
components['schemas']['IterateInvocationOutput']; export type CollectInvocationOutput = S<'CollectInvocationOutput'>;
export type CollectInvocationOutput = export type LatentsOutput = S<'LatentsOutput'>;
components['schemas']['CollectInvocationOutput']; export type GraphInvocationOutput = S<'GraphInvocationOutput'>;
export type LatentsOutput = components['schemas']['LatentsOutput'];
export type GraphInvocationOutput =
components['schemas']['GraphInvocationOutput'];

View File

@ -9,6 +9,7 @@
"vite.config.ts", "vite.config.ts",
"./config/vite.app.config.ts", "./config/vite.app.config.ts",
"./config/vite.package.config.ts", "./config/vite.package.config.ts",
"./config/vite.common.config.ts" "./config/vite.common.config.ts",
"./config/common.ts"
] ]
} }

File diff suppressed because it is too large Load Diff

View File

@ -39,7 +39,7 @@ dependencies = [
"click", "click",
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel>=1.2.1", "compel>=1.2.1",
"controlnet-aux>=0.0.4", "controlnet-aux>=0.0.6",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"datasets", "datasets",
"diffusers[torch]~=0.17.1", "diffusers[torch]~=0.17.1",

30
tests/conftest.py Normal file
View File

@ -0,0 +1,30 @@
import pytest
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.graph import LibraryGraph, GraphExecutionState
from invokeai.app.services.processor import DefaultInvocationProcessor
# Ignore these files as they need to be rewritten following the model manager refactor
collect_ignore = ["nodes/test_graph_execution_state.py", "nodes/test_node_graph.py", "test_textual_inversion.py"]
@pytest.fixture(scope="session", autouse=True)
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
board_images=None, # type: ignore
boards=None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)

View File

@ -1,14 +1,18 @@
from .test_invoker import create_edge import pytest
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput,
InvocationContext)
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.processor import DefaultInvocationProcessor from invokeai.app.services.graph import (CollectInvocation, Graph,
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory GraphExecutionState,
from invokeai.app.services.invocation_queue import MemoryInvocationQueue IterateInvocation)
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest from .test_invoker import create_edge
from .test_nodes import (ImageTestInvocation, PromptCollectionTestInvocation,
PromptTestInvocation)
@pytest.fixture @pytest.fixture
@ -19,25 +23,6 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
@pytest.fixture
def mock_services():
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = None, # type: ignore
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
n = g.next() n = g.next()
if n is None: if n is None:

View File

@ -1,13 +1,12 @@
from .test_nodes import ErrorInvocation, ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation, TestEventService, create_edge, wait_until
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invoker import Invoker
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
import pytest import pytest
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from .test_nodes import (ErrorInvocation, ImageTestInvocation,
PromptTestInvocation, create_edge, wait_until)
@pytest.fixture @pytest.fixture
def simple_graph(): def simple_graph():
@ -17,25 +16,6 @@ def simple_graph():
g.add_edge(create_edge("1", "prompt", "2", "prompt")) g.add_edge(create_edge("1", "prompt", "2", "prompt"))
return g return g
@pytest.fixture
def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
model_manager = None, # type: ignore
events = TestEventService(),
logger = None, # type: ignore
images = None, # type: ignore
latents = None, # type: ignore
queue = MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=sqlite_memory, table_name="graphs"
),
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
processor = DefaultInvocationProcessor(),
restoration = None, # type: ignore
configuration = None, # type: ignore
)
@pytest.fixture() @pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker: def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker( return Invoker(
@ -57,6 +37,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
assert isinstance(g, GraphExecutionState) assert isinstance(g, GraphExecutionState)
assert g.graph == simple_graph assert g.graph == simple_graph
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph): def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph) g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g) invocation_id = mock_invoker.invoke(g)
@ -72,6 +53,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0 assert len(g.executed) > 0
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph): def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph = simple_graph) g = mock_invoker.create_execution_state(graph = simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all = True) invocation_id = mock_invoker.invoke(g, invoke_all = True)
@ -87,6 +69,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.services.graph_execution_manager.get(g.id) g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete() assert g.is_complete()
@pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_handles_errors(mock_invoker: Invoker): def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state() g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id = "1")) g.graph.add_node(ErrorInvocation(id = "1"))