merge with recent main changes

Commit a705461c04 by Lincoln Stein, 2023-07-26 06:39:21 -04:00
114 changed files with 3523 additions and 506 deletions

View File

@@ -1,11 +1,11 @@
 name: Close inactive issues
 on:
   schedule:
-    - cron: "00 6 * * *"
+    - cron: "00 4 * * *"
 env:
-  DAYS_BEFORE_ISSUE_STALE: 14
-  DAYS_BEFORE_ISSUE_CLOSE: 28
+  DAYS_BEFORE_ISSUE_STALE: 30
+  DAYS_BEFORE_ISSUE_CLOSE: 14
 jobs:
   close-issues:
@@ -14,7 +14,7 @@ jobs:
       issues: write
       pull-requests: write
     steps:
-      - uses: actions/stale@v5
+      - uses: actions/stale@v8
        with:
          days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
          days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
@@ -23,5 +23,6 @@ jobs:
          close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
          days-before-pr-stale: -1
          days-before-pr-close: -1
+         exempt-issue-labels: "Active Issue"
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          operations-per-run: 500

[12 new binary image files (490, 335, 217, 244, 948, 292, 420, 179, 216, 439, 563, and 353 KiB); contents not shown. These are presumably the ../assets/nodes/*.png screenshots that the documentation changes below now reference.]

View File

@@ -118,49 +118,49 @@ There are several node grouping concepts that can be examined with a narrow focu
 As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input.
-<img width="654" alt="groupsnoise" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/2e8d297e-ad55-4d27-bc93-c119dad2a2c5">
+![groupsnoise](../assets/nodes/groupsnoise.png)
 ### Conditioning
 As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node.
-<img width="1024" alt="groupsconditioning" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/f8f7ad8a-8d9c-418e-b5ad-1437b774b27e">
+![groupsconditioning](../assets/nodes/groupsconditioning.png)
 ### Image Space & VAE
 The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space.
-<img width="637" alt="groupsimgvae" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/dd99969c-e0a8-4f78-9b17-3ffe179cef9a">
+![groupsimgvae](../assets/nodes/groupsimgvae.png)
 ### Defined & Random Seeds
 It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
-<img width="922" alt="groupsrandseed" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/af55bc20-60f6-438e-aba5-3ec871443710">
+![groupsrandseed](../assets/nodes/groupsrandseed.png)
 ### Control
 Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet.
-<img width="805" alt="groupscontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/cc9c5de7-23a7-46c8-bbad-1f3609d999a6">
+![groupscontrol](../assets/nodes/groupscontrol.png)
 ### LoRA
 The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction.
-<img width="993" alt="groupslora" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/630962b0-d914-4505-b3ea-ccae9b0269da">
+![groupslora](../assets/nodes/groupslora.png)
 ### Scaling
 Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.
-<img width="644" alt="groupsallscale" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/99314f05-dd9f-4b6d-b378-31de55346a13">
+![groupsallscale](../assets/nodes/groupsallscale.png)
 ### Iteration + Multiple Images as Input
 Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time.
-<img width="788" alt="groupsiterate" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/4af5ca27-82c9-4018-8c5b-024d3ee0a121">
+![groupsiterate](../assets/nodes/groupsiterate.png)
 ### Multiple Image Generation + Random Seeds
@@ -168,7 +168,7 @@ Multiple image generation in the node editor is done using the RandomRange node.
 To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.
-<img width="1027" alt="groupsmultigenseeding" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/518d1b2b-fed1-416b-a052-ab06552521b3">
+![groupsmultigenseeding](../assets/nodes/groupsmultigenseeding.png)
 ## Examples
@@ -176,7 +176,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
 ### Basic text-to-image Node Graph
-<img width="875" alt="nodest2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/17c67720-c376-4db8-94f0-5e00381a61ee">
+![nodest2i](../assets/nodes/nodest2i.png)
 - Model Loader: A necessity to generating images (as we’ve read above). We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE.
 - Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, “dog”), one will output negative (what you don’t want, “cat”). They both input the CLIP tokenizer that the Model Loader node outputs.
@@ -186,7 +186,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
 ### Basic image-to-image Node Graph
-<img width="998" alt="nodesi2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/3f2c95d5-cee7-4415-9b79-b46ee60a92fe">
+![nodesi2i](../assets/nodes/nodesi2i.png)
 - Model Loader: Choose a model from the dropdown.
 - Prompt (Compel): Two prompt nodes. One positive (“dog”), one negative (“dog”). Same CLIP inputs from the Model Loader node as before.
@@ -197,7 +197,7 @@ With our knowledge of node grouping and the diffusion process, let’s break dow
 ### Basic ControlNet Node Graph
-<img width="703" alt="nodescontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/b02ded86-ceb4-44a2-9910-e19ad184d471">
+![nodescontrol](../assets/nodes/nodescontrol.png)
 - Model Loader
 - Prompt (Compel)

View File

@@ -298,7 +298,7 @@ async def search_for_models(
 )->List[pathlib.Path]:
     if not search_path.is_dir():
         raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
-    return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+    return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
 @models_router.get(
     "/ckpt_confs",

View File

@@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.clip.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
 class SDXLPromptInvocationBase:
     def run_clip_raw(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.dict(), context=context,
         )
         def _lora_loader():
             for lora in clip_field.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
                     model_name=name,
                     base_model=clip_field.text_encoder.base_model,
                     model_type=ModelType.TextualInversion,
+                    context=context,
                 ).context.model
             )
         except ModelNotFoundException:
@@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
     def run_clip_compel(self, context, clip_field, prompt, get_pooled):
         tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.dict(),
+            **clip_field.tokenizer.dict(), context=context,
         )
         text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.dict(),
+            **clip_field.text_encoder.dict(), context=context,
         )
         def _lora_loader():
             for lora in clip_field.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"}), context=context)
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
@@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
                     model_name=name,
                     base_model=clip_field.text_encoder.base_model,
                     model_type=ModelType.TextualInversion,
+                    context=context,
                 ).context.model
             )
         except ModelNotFoundException:
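
Note: the change in this file is mechanical — every model_manager.get_model(...) call now also receives the invocation context, presumably so the model manager can tie model-load events to the running graph execution. A minimal sketch of the lazy LoRA-loading generator pattern used above, with a hypothetical get_model standing in for the real service method:

def lora_loader(loras, get_model, context):
    """Yield (model, weight) pairs one at a time so each LoRA can be
    loaded, applied, and released before the next one is touched."""
    for lora in loras:
        # exclude the weight: it is an application parameter, not a model key
        lora_info = get_model(**lora.dict(exclude={"weight"}), context=context)
        yield (lora_info.context.model, lora.weight)
        del lora_info  # drop the reference promptly to keep memory bounded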

View File

@@ -2,15 +2,18 @@ from typing import Literal, Optional, Union
 from pydantic import BaseModel, Field
-from invokeai.app.invocations.baseinvocation import (BaseInvocation,
-                                                     BaseInvocationOutput, InvocationConfig,
-                                                     InvocationContext)
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InvocationConfig,
+    InvocationContext,
+)
 from invokeai.app.invocations.controlnet_image_processors import ControlField
-from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
-                                            VAEModelField)
+from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 class LoRAMetadataField(BaseModel):
     """LoRA metadata for an image generated in InvokeAI."""
     lora: LoRAModelField = Field(description="The LoRA model")
     weight: float = Field(description="The weight of the LoRA model")
@@ -18,7 +21,9 @@ class LoRAMetadataField(BaseModel):
 class CoreMetadata(BaseModel):
     """Core generation metadata for an image generated in InvokeAI."""
-    generation_mode: str = Field(description="The generation mode that output this image",)
+    generation_mode: str = Field(
+        description="The generation mode that output this image",
+    )
     positive_prompt: str = Field(description="The positive prompt parameter")
     negative_prompt: str = Field(description="The negative prompt parameter")
     width: int = Field(description="The width parameter")
@@ -28,10 +33,20 @@ class CoreMetadata(BaseModel):
     cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
     steps: int = Field(description="The number of steps used for inference")
     scheduler: str = Field(description="The scheduler used for inference")
-    clip_skip: int = Field(description="The number of skipped CLIP layers",)
+    clip_skip: int = Field(
+        description="The number of skipped CLIP layers",
+    )
     model: MainModelField = Field(description="The main model used for inference")
-    controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
+    controlnets: list[ControlField] = Field(
+        description="The ControlNets used for inference"
+    )
     loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
+    vae: Union[VAEModelField, None] = Field(
+        default=None,
+        description="The VAE used for decoding, if the main model's default was not used",
+    )
+    # Latents-to-Latents
     strength: Union[float, None] = Field(
         default=None,
         description="The strength used for latents-to-latents",
@@ -39,9 +54,34 @@ class CoreMetadata(BaseModel):
     init_image: Union[str, None] = Field(
         default=None, description="The name of the initial image"
     )
-    vae: Union[VAEModelField, None] = Field(
-        default=None,
-        description="The VAE used for decoding, if the main model's default was not used",
-    )
+    # SDXL
+    positive_style_prompt: Union[str, None] = Field(
+        default=None, description="The positive style prompt parameter"
+    )
+    negative_style_prompt: Union[str, None] = Field(
+        default=None, description="The negative style prompt parameter"
+    )
+    # SDXL Refiner
+    refiner_model: Union[MainModelField, None] = Field(
+        default=None, description="The SDXL Refiner model used"
+    )
+    refiner_cfg_scale: Union[float, None] = Field(
+        default=None,
+        description="The classifier-free guidance scale parameter used for the refiner",
+    )
+    refiner_steps: Union[int, None] = Field(
+        default=None, description="The number of steps used for the refiner"
+    )
+    refiner_scheduler: Union[str, None] = Field(
+        default=None, description="The scheduler used for the refiner"
+    )
+    refiner_aesthetic_store: Union[float, None] = Field(
+        default=None, description="The aesthetic score used for the refiner"
+    )
+    refiner_start: Union[float, None] = Field(
+        default=None, description="The start value used for refiner denoising"
+    )
@@ -70,7 +110,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     type: Literal["metadata_accumulator"] = "metadata_accumulator"
-    generation_mode: str = Field(description="The generation mode that output this image",)
+    generation_mode: str = Field(
+        description="The generation mode that output this image",
+    )
     positive_prompt: str = Field(description="The positive prompt parameter")
     negative_prompt: str = Field(description="The negative prompt parameter")
     width: int = Field(description="The width parameter")
@@ -80,9 +122,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
     cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
     steps: int = Field(description="The number of steps used for inference")
     scheduler: str = Field(description="The scheduler used for inference")
-    clip_skip: int = Field(description="The number of skipped CLIP layers",)
+    clip_skip: int = Field(
+        description="The number of skipped CLIP layers",
+    )
     model: MainModelField = Field(description="The main model used for inference")
-    controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
+    controlnets: list[ControlField] = Field(
+        description="The ControlNets used for inference"
+    )
     loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
     strength: Union[float, None] = Field(
         default=None,
@@ -96,36 +142,44 @@ class MetadataAccumulatorInvocation(BaseInvocation):
         description="The VAE used for decoding, if the main model's default was not used",
     )
+    # SDXL
+    positive_style_prompt: Union[str, None] = Field(
+        default=None, description="The positive style prompt parameter"
+    )
+    negative_style_prompt: Union[str, None] = Field(
+        default=None, description="The negative style prompt parameter"
+    )
+    # SDXL Refiner
+    refiner_model: Union[MainModelField, None] = Field(
+        default=None, description="The SDXL Refiner model used"
+    )
+    refiner_cfg_scale: Union[float, None] = Field(
+        default=None,
+        description="The classifier-free guidance scale parameter used for the refiner",
+    )
+    refiner_steps: Union[int, None] = Field(
+        default=None, description="The number of steps used for the refiner"
+    )
+    refiner_scheduler: Union[str, None] = Field(
+        default=None, description="The scheduler used for the refiner"
+    )
+    refiner_aesthetic_store: Union[float, None] = Field(
+        default=None, description="The aesthetic score used for the refiner"
+    )
+    refiner_start: Union[float, None] = Field(
+        default=None, description="The start value used for refiner denoising"
+    )
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
                 "title": "Metadata Accumulator",
-                "tags": ["image", "metadata", "generation"]
+                "tags": ["image", "metadata", "generation"],
             },
         }
     def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
         """Collects and outputs a CoreMetadata object"""
-        return MetadataAccumulatorOutput(
-            metadata=CoreMetadata(
-                generation_mode=self.generation_mode,
-                positive_prompt=self.positive_prompt,
-                negative_prompt=self.negative_prompt,
-                width=self.width,
-                height=self.height,
-                seed=self.seed,
-                rand_device=self.rand_device,
-                cfg_scale=self.cfg_scale,
-                steps=self.steps,
-                scheduler=self.scheduler,
-                model=self.model,
-                strength=self.strength,
-                init_image=self.init_image,
-                vae=self.vae,
-                controlnets=self.controlnets,
-                loras=self.loras,
-                clip_skip=self.clip_skip,
-            )
-        )
+        return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
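
Note: the invoke() rewrite leans on pydantic v1's default behavior of ignoring unknown keyword arguments (Extra.ignore), so the accumulator can dump all of its fields and let CoreMetadata keep only those it declares. A small self-contained illustration of the idiom (toy models, not the real classes):

from pydantic import BaseModel

class CoreMeta(BaseModel):
    seed: int
    steps: int

class Accumulator(BaseModel):
    # a superset of CoreMeta's fields, plus bookkeeping CoreMeta doesn't declare
    type: str = "metadata_accumulator"
    seed: int = 123
    steps: int = 30

meta = CoreMeta(**Accumulator().dict())  # the extra "type" key is silently dropped
assert meta.dict() == {"seed": 123, "steps": 30}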

View File

@@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
     @validator("seed", pre=True)
     def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
+        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        return v % (SEED_MAX + 1)
     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
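
Note: the validator fix is an off-by-one. x % n yields values in [0, n-1], so the old v % SEED_MAX could never return SEED_MAX itself — entering the maximum valid seed silently wrapped it to 0. Modulo (SEED_MAX + 1) keeps the full inclusive range:

SEED_MAX = 4294967295  # np.iinfo(np.uint32).max, per the SEED_MAX change below

assert SEED_MAX % SEED_MAX == 0                # old: the max seed wrapped to 0
assert SEED_MAX % (SEED_MAX + 1) == SEED_MAX   # new: the max seed is preserved
assert (SEED_MAX + 1) % (SEED_MAX + 1) == 0    # first out-of-range value wraps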

View File

@@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
             "ui": {
                 "title": "SDXL Refiner Model Loader",
                 "tags": ["model", "loader", "sdxl_refiner"],
-                "type_hints": {"model": "model"},
+                "type_hints": {"model": "refiner_model"},
             },
         }
@@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context
         )
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
@@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
     unet: UNetField = Field(default=None, description="UNet submodel")
     latents: Optional[LatentsField] = Field(description="Initial latents")
-    denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
-    denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
+    denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
+    denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
     #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
     #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
         num_inference_steps = num_inference_steps - t_start
         # apply noise(if provided)
-        if self.noise is not None:
+        if self.noise is not None and timesteps.shape[0] > 0:
             noise = context.services.latents.get(self.noise.latents_name)
             latents = scheduler.add_noise(latents, noise, timesteps[:1])
             del noise
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict()
+            **self.unet.unet.dict(), context=context,
         )
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
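
Note: the new timesteps.shape[0] > 0 guard covers the edge case where denoising_start (which may now be exactly 1.0, per the le=1 change above) leaves no scheduler steps: timesteps[:1] is then empty and scheduler.add_noise has no valid timestep to read. A rough sketch of the arithmetic, assuming the usual derivation of t_start from denoising_start:

import torch

num_inference_steps = 30
timesteps_full = torch.linspace(999, 0, num_inference_steps).long()

denoising_start = 1.0  # start at the very end of the schedule
t_start = int(round(num_inference_steps * denoising_start))
timesteps = timesteps_full[t_start:]

# the old code called scheduler.add_noise(latents, noise, timesteps[:1])
# unconditionally; with denoising_start == 1.0 that slice is empty
assert timesteps.shape[0] == 0 and timesteps[:1].numel() == 0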

View File

@@ -3,7 +3,13 @@
 from typing import Any, Optional
 from invokeai.app.models.image import ProgressImage
 from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
+from invokeai.app.services.model_manager_service import (
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    ModelInfo,
+)
 class EventServiceBase:
     session_event: str = "session_event"
@@ -38,7 +44,9 @@ class EventServiceBase:
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
-                progress_image=progress_image.dict() if progress_image is not None else None,
+                progress_image=progress_image.dict()
+                if progress_image is not None
+                else None,
                 step=step,
                 total_steps=total_steps,
             ),
@@ -67,6 +75,7 @@ class EventServiceBase:
         graph_execution_state_id: str,
         node: dict,
         source_node_id: str,
+        error_type: str,
         error: str,
     ) -> None:
         """Emitted when an invocation has completed"""
@@ -76,6 +85,7 @@ class EventServiceBase:
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
+                error_type=error_type,
                 error=error,
             ),
         )
@@ -102,7 +112,7 @@ class EventServiceBase:
             ),
         )
-    def emit_model_load_started (
+    def emit_model_load_started(
         self,
         graph_execution_state_id: str,
         model_name: str,
@@ -145,3 +155,37 @@ class EventServiceBase:
                 precision=str(model_info.precision),
             ),
         )
+    def emit_session_retrieval_error(
+        self,
+        graph_execution_state_id: str,
+        error_type: str,
+        error: str,
+    ) -> None:
+        """Emitted when session retrieval fails"""
+        self.__emit_session_event(
+            event_name="session_retrieval_error",
+            payload=dict(
+                graph_execution_state_id=graph_execution_state_id,
+                error_type=error_type,
+                error=error,
+            ),
+        )
+    def emit_invocation_retrieval_error(
+        self,
+        graph_execution_state_id: str,
+        node_id: str,
+        error_type: str,
+        error: str,
+    ) -> None:
+        """Emitted when invocation retrieval fails"""
+        self.__emit_session_event(
+            event_name="invocation_retrieval_error",
+            payload=dict(
+                graph_execution_state_id=graph_execution_state_id,
+                node_id=node_id,
+                error_type=error_type,
+                error=error,
+            ),
+        )
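
Note: both new events — and the extended emit_invocation_error — split a failure into a machine-matchable error_type and a human-readable traceback, which is how the processor below populates them. A minimal sketch of building that payload from a caught exception:

import traceback

def error_payload(graph_execution_state_id: str, e: Exception) -> dict:
    # error_type lets clients branch on the exception class by name;
    # error carries the full formatted traceback for display
    return dict(
        graph_execution_state_id=graph_execution_state_id,
        error_type=e.__class__.__name__,
        error=traceback.format_exc(),
    )

try:
    raise KeyError("missing session")
except Exception as e:
    payload = error_payload("abc-123", e)
    assert payload["error_type"] == "KeyError"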

View File

@@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         Return list of all models found in the designated directory.
         """
-        search = FindModels(directory,self.logger)
+        search = FindModels([directory], self.logger)
         return search.list_models()
     def sync_to_config(self):

View File

@@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                 except Exception as e:
-                    logger.debug("Exception while getting from queue: %s" % e)
+                    self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
                 if not queue_item:  # Probably stopping
                     # do not hammer the queue
                     time.sleep(0.5)
                     continue
-                graph_execution_state = (
-                    self.__invoker.services.graph_execution_manager.get(
-                        queue_item.graph_execution_state_id
-                    )
-                )
+                try:
+                    graph_execution_state = (
+                        self.__invoker.services.graph_execution_manager.get(
+                            queue_item.graph_execution_state_id
+                        )
+                    )
+                except Exception as e:
+                    self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
+                    self.__invoker.services.events.emit_session_retrieval_error(
+                        graph_execution_state_id=queue_item.graph_execution_state_id,
+                        error_type=e.__class__.__name__,
+                        error=traceback.format_exc(),
+                    )
+                    continue
-                invocation = graph_execution_state.execution_graph.get_node(
-                    queue_item.invocation_id
-                )
+                try:
+                    invocation = graph_execution_state.execution_graph.get_node(
+                        queue_item.invocation_id
+                    )
+                except Exception as e:
+                    self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
+                    self.__invoker.services.events.emit_invocation_retrieval_error(
+                        graph_execution_state_id=queue_item.graph_execution_state_id,
+                        node_id=queue_item.invocation_id,
+                        error_type=e.__class__.__name__,
+                        error=traceback.format_exc(),
+                    )
+                    continue
                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         graph_execution_state
                     )
+                    self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                     # Send error event
                     self.__invoker.services.events.emit_invocation_error(
                         graph_execution_state_id=graph_execution_state.id,
                         node=invocation.dict(),
                         source_node_id=source_node_id,
+                        error_type=e.__class__.__name__,
                         error=error,
                     )
@@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     try:
                         self.__invoker.invoke(graph_execution_state, invoke_all=True)
                     except Exception as e:
-                        logger.error("Error while invoking: %s" % e)
+                        self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                         self.__invoker.services.events.emit_invocation_error(
                             graph_execution_state_id=graph_execution_state.id,
                             node=invocation.dict(),
                             source_node_id=source_node_id,
+                            error_type=e.__class__.__name__,
                             error=traceback.format_exc()
                         )
                 elif is_complete:
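
Note: the shape of the change is the same for both lookups — each retrieval gets its own try/except that logs, emits the matching retrieval-error event, and moves on to the next queue item rather than letting one bad item kill the worker loop. Schematically, with stand-in objects for the session lookup and event service:

class StubEvents:
    def emit_session_retrieval_error(self, **payload):
        print("session_retrieval_error", payload)

def get_session(item_id):
    if item_id == "b":
        raise KeyError(item_id)  # simulate a missing session
    return f"session-{item_id}"

def process_queue(item_ids, events=StubEvents()):
    for item_id in item_ids:
        try:
            session = get_session(item_id)
        except Exception as e:
            events.emit_session_retrieval_error(
                graph_execution_state_id=item_id,
                error_type=e.__class__.__name__,
                error=str(e),
            )
            continue  # one bad item must not stop the worker
        print("processing", session)

process_queue(["a", "b", "c"])  # "b" fails; "a" and "c" still run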

View File

@@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
     return datetime.datetime.fromisoformat(iso_timestamp)
-SEED_MAX = np.iinfo(np.int32).max
+SEED_MAX = np.iinfo(np.uint32).max
 def get_random_seed():
-    return np.random.randint(0, SEED_MAX)
+    rng = np.random.default_rng(seed=None)
+    return int(rng.integers(0, SEED_MAX))
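
Note: two things change here. The seed space widens from int32 to uint32, matching the frontend NUMPY_RAND_MAX update below, and the result is coerced to a plain Python int — numpy scalars are not JSON-serializable, which matters once seeds travel through the API. A quick check of both points:

import json
import numpy as np

assert np.iinfo(np.int32).max == 2_147_483_647
assert np.iinfo(np.uint32).max == 4_294_967_295

rng = np.random.default_rng()
raw = rng.integers(0, np.iinfo(np.uint32).max)  # numpy scalar, not a Python int
json.dumps({"seed": int(raw)})  # fine
# json.dumps({"seed": raw}) raises TypeError: Object of type int64 is not JSON serializable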

View File

@@ -474,7 +474,7 @@ class ModelPatcher:
     @staticmethod
     def _lora_forward_hook(
-        applied_loras: List[Tuple[LoraModel, float]],
+        applied_loras: List[Tuple[LoRAModel, float]],
         layer_name: str,
     ):
@@ -519,7 +519,7 @@ class ModelPatcher:
     def apply_lora(
         cls,
         model: torch.nn.Module,
-        loras: List[Tuple[LoraModel, float]],
+        loras: List[Tuple[LoRAModel, float]],
         prefix: str,
     ):
         original_weights = dict()

View File

@@ -98,6 +98,6 @@ class FindModels(ModelSearch):
     def list_models(self) -> List[Path]:
         self.search()
-        return self.models_found
+        return list(self.models_found)
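
Note: wrapping the result in list() makes the return value actually match the List[Path] annotation; if models_found is a set (a reasonable guess, since a search can hit the same file twice), the conversion also restores indexing and allows deterministic ordering:

from pathlib import Path

models_found = {Path("a.ckpt"), Path("b.safetensors")}  # hypothetical search hits
models = list(models_found)
models[0]        # sets do not support indexing; lists do
sorted(models)   # callers can impose a stable order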

View File

@@ -10,6 +10,7 @@ from .base import (
     SubModelType,
     classproperty,
     InvalidModelException,
+    ModelNotFoundException,
 )
 # TODO: naming
 from ..lora import LoRAModel as LoRAModelRaw

View File

@@ -1,2 +1,2 @@
 export const NUMPY_RAND_MIN = 0;
-export const NUMPY_RAND_MAX = 2147483647;
+export const NUMPY_RAND_MAX = 4294967295;
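
Note: the new value is 2**32 - 1, keeping this TypeScript constant in lockstep with the SEED_MAX change above, and it remains far below JavaScript's Number.MAX_SAFE_INTEGER, so plain JS numbers are still safe for seeds:

assert 4294967295 == 2**32 - 1 < 2**53 - 1  # uint32 max vs JS MAX_SAFE_INTEGER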

View File

@@ -65,11 +65,14 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
 import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
 import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
 import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
+import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
 import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
 import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
+import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
 import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
 import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
 import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
+import { addTabChangedListener } from './listeners/tabChanged';
 import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
 import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
 import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
@@ -153,6 +156,8 @@ addSocketDisconnectedListener();
 addSocketSubscribedListener();
 addSocketUnsubscribedListener();
 addModelLoadEventListener();
+addSessionRetrievalErrorEventListener();
+addInvocationRetrievalErrorEventListener();
 // Session Created
 addSessionCreatedPendingListener();
@@ -197,3 +202,6 @@ addFirstListImagesListener();
 // Ad-hoc upscale workflwo
 addUpscaleRequestedListener();
+// Tab Change
+addTabChangedListener();

View File

@@ -1,4 +1,8 @@
 import { setInfillMethod } from 'features/parameters/store/generationSlice';
+import {
+  shouldUseNSFWCheckerChanged,
+  shouldUseWatermarkerChanged,
+} from 'features/system/store/systemSlice';
 import { appInfoApi } from 'services/api/endpoints/appInfo';
 import { startAppListening } from '..';
@@ -6,12 +10,21 @@ export const addAppConfigReceivedListener = () => {
   startAppListening({
     matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
     effect: async (action, { getState, dispatch }) => {
-      const { infill_methods } = action.payload;
+      const { infill_methods, nsfw_methods, watermarking_methods } =
+        action.payload;
       const infillMethod = getState().generation.infillMethod;
       if (!infill_methods.includes(infillMethod)) {
         dispatch(setInfillMethod(infill_methods[0]));
       }
+      if (!nsfw_methods.includes('nsfw_checker')) {
+        dispatch(shouldUseNSFWCheckerChanged(false));
+      }
+      if (!watermarking_methods.includes('invisible_watermark')) {
+        dispatch(shouldUseWatermarkerChanged(false));
+      }
     },
   });
 };

View File

@@ -9,13 +9,19 @@ import {
   zMainModel,
   zVaeModel,
 } from 'features/parameters/types/parameterSchemas';
+import {
+  refinerModelChanged,
+  setShouldUseSDXLRefiner,
+} from 'features/sdxl/store/sdxlSlice';
 import { forEach, some } from 'lodash-es';
 import { modelsApi } from 'services/api/endpoints/models';
 import { startAppListening } from '..';
 export const addModelsLoadedListener = () => {
   startAppListening({
-    matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
+    predicate: (state, action) =>
+      modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
+      !action.meta.arg.originalArgs.includes('sdxl-refiner'),
     effect: async (action, { getState, dispatch }) => {
       // models loaded, we need to ensure the selected model is available and if not, select the first one
       const log = logger('models');
@@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
       dispatch(modelChanged(result.data));
     },
   });
+  startAppListening({
+    predicate: (state, action) =>
+      modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
+      action.meta.arg.originalArgs.includes('sdxl-refiner'),
+    effect: async (action, { getState, dispatch }) => {
+      // models loaded, we need to ensure the selected model is available and if not, select the first one
+      const log = logger('models');
+      log.info(
+        { models: action.payload.entities },
+        `SDXL Refiner models loaded (${action.payload.ids.length})`
+      );
+      const currentModel = getState().sdxl.refinerModel;
+      const isCurrentModelAvailable = some(
+        action.payload.entities,
+        (m) =>
+          m?.model_name === currentModel?.model_name &&
+          m?.base_model === currentModel?.base_model
+      );
+      if (isCurrentModelAvailable) {
+        return;
+      }
+      const firstModelId = action.payload.ids[0];
+      const firstModel = action.payload.entities[firstModelId];
+      if (!firstModel) {
+        // No models loaded at all
+        dispatch(refinerModelChanged(null));
+        dispatch(setShouldUseSDXLRefiner(false));
+        return;
+      }
+      const result = zMainModel.safeParse(firstModel);
+      if (!result.success) {
+        log.error(
+          { error: result.error.format() },
+          'Failed to parse SDXL Refiner Model'
+        );
+        return;
+      }
+      dispatch(refinerModelChanged(result.data));
+    },
+  });
   startAppListening({
     matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
     effect: async (action, { getState, dispatch }) => {

View File

@@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
     effect: (action) => {
       const log = logger('session');
       if (action.payload) {
-        const { error } = action.payload;
+        const { error, status } = action.payload;
         const graph = parseify(action.meta.arg);
-        const stringifiedError = JSON.stringify(error);
         log.error(
-          { graph, error: serializeError(error) },
-          `Problem creating session: ${stringifiedError}`
+          { graph, status, error: serializeError(error) },
+          `Problem creating session`
         );
       }
     },

View File

@@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
       const { session_id } = action.meta.arg;
       if (action.payload) {
         const { error } = action.payload;
-        const stringifiedError = JSON.stringify(error);
         log.error(
           {
             session_id,
             error: serializeError(error),
           },
-          `Problem invoking session: ${stringifiedError}`
+          `Problem invoking session`
         );
       }
     },

View File

@@ -1,4 +1,6 @@
 import { logger } from 'app/logging/logger';
+import { LIST_TAG } from 'services/api';
+import { appInfoApi } from 'services/api/endpoints/appInfo';
 import { modelsApi } from 'services/api/endpoints/models';
 import { receivedOpenAPISchema } from 'services/api/thunks/schema';
 import { appSocketConnected, socketConnected } from 'services/events/actions';
@@ -24,11 +26,18 @@ export const addSocketConnectedEventListener = () => {
       dispatch(appSocketConnected(action.payload));
       // update all server state
-      dispatch(modelsApi.endpoints.getMainModels.initiate());
-      dispatch(modelsApi.endpoints.getControlNetModels.initiate());
-      dispatch(modelsApi.endpoints.getLoRAModels.initiate());
-      dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
-      dispatch(modelsApi.endpoints.getVaeModels.initiate());
+      dispatch(
+        modelsApi.util.invalidateTags([
+          { type: 'MainModel', id: LIST_TAG },
+          { type: 'SDXLRefinerModel', id: LIST_TAG },
+          { type: 'LoRAModel', id: LIST_TAG },
+          { type: 'ControlNetModel', id: LIST_TAG },
+          { type: 'VaeModel', id: LIST_TAG },
+          { type: 'TextualInversionModel', id: LIST_TAG },
+          { type: 'ScannedModels', id: LIST_TAG },
+        ])
+      );
+      dispatch(appInfoApi.util.invalidateTags(['AppConfig', 'AppVersion']));
     },
   });
 };

View File

@@ -0,0 +1,20 @@
+import { logger } from 'app/logging/logger';
+import {
+  appSocketInvocationRetrievalError,
+  socketInvocationRetrievalError,
+} from 'services/events/actions';
+import { startAppListening } from '../..';
+export const addInvocationRetrievalErrorEventListener = () => {
+  startAppListening({
+    actionCreator: socketInvocationRetrievalError,
+    effect: (action, { dispatch }) => {
+      const log = logger('socketio');
+      log.error(
+        action.payload,
+        `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
+      );
+      dispatch(appSocketInvocationRetrievalError(action.payload));
+    },
+  });
+};

View File

@@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
         return;
       }
-      log.debug(action.payload, 'Invocation started');
+      log.debug(
+        action.payload,
+        `Invocation started (${action.payload.data.node.type})`
+      );
       dispatch(appSocketInvocationStarted(action.payload));
     },
   });

View File

@@ -0,0 +1,20 @@
+import { logger } from 'app/logging/logger';
+import {
+  appSocketSessionRetrievalError,
+  socketSessionRetrievalError,
+} from 'services/events/actions';
+import { startAppListening } from '../..';
+export const addSessionRetrievalErrorEventListener = () => {
+  startAppListening({
+    actionCreator: socketSessionRetrievalError,
+    effect: (action, { dispatch }) => {
+      const log = logger('socketio');
+      log.error(
+        action.payload,
+        `Session retrieval error (${action.payload.data.graph_execution_state_id})`
+      );
+      dispatch(appSocketSessionRetrievalError(action.payload));
+    },
+  });
+};

View File

@@ -0,0 +1,56 @@
+import { modelChanged } from 'features/parameters/store/generationSlice';
+import { setActiveTab } from 'features/ui/store/uiSlice';
+import { forEach } from 'lodash-es';
+import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
+import {
+  MainModelConfigEntity,
+  modelsApi,
+} from 'services/api/endpoints/models';
+import { startAppListening } from '..';
+export const addTabChangedListener = () => {
+  startAppListening({
+    actionCreator: setActiveTab,
+    effect: (action, { getState, dispatch }) => {
+      const activeTabName = action.payload;
+      if (activeTabName === 'unifiedCanvas') {
+        // grab the models from RTK Query cache
+        const { data } = modelsApi.endpoints.getMainModels.select(
+          NON_REFINER_BASE_MODELS
+        )(getState());
+        if (!data) {
+          // no models yet, so we can't do anything
+          dispatch(modelChanged(null));
+          return;
+        }
+        // need to filter out all the invalid canvas models (currently, this is just sdxl)
+        const validCanvasModels: MainModelConfigEntity[] = [];
+        forEach(data.entities, (entity) => {
+          if (!entity) {
+            return;
+          }
+          if (['sd-1', 'sd-2'].includes(entity.base_model)) {
+            validCanvasModels.push(entity);
+          }
+        });
+        // this could still be undefined even tho TS doesn't say so
+        const firstValidCanvasModel = validCanvasModels[0];
+        if (!firstValidCanvasModel) {
+          // uh oh, we have no models that are valid for canvas
+          dispatch(modelChanged(null));
+          return;
+        }
+        // only store the model name and base model in redux
+        const { base_model, model_name } = firstValidCanvasModel;
+        dispatch(modelChanged({ base_model, model_name }));
+      }
+    },
+  });
+};

View File

@@ -39,8 +39,22 @@ export const addUserInvokedCanvasListener = () => {
       const state = getState();
+      const {
+        layerState,
+        boundingBoxCoordinates,
+        boundingBoxDimensions,
+        isMaskEnabled,
+        shouldPreserveMaskedArea,
+      } = state.canvas;
       // Build canvas blobs
-      const canvasBlobsAndImageData = await getCanvasData(state);
+      const canvasBlobsAndImageData = await getCanvasData(
+        layerState,
+        boundingBoxCoordinates,
+        boundingBoxDimensions,
+        isMaskEnabled,
+        shouldPreserveMaskedArea
+      );
       if (!canvasBlobsAndImageData) {
         log.error('Unable to create canvas data');

View File

@@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
 import { parseify } from 'common/util/serialize';
 import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
 import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
+import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
 import { sessionReadyToInvoke } from 'features/system/store/actions';
 import { sessionCreated } from 'services/api/thunks/session';
 import { startAppListening } from '..';
@@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
     effect: async (action, { getState, dispatch, take }) => {
       const log = logger('session');
       const state = getState();
-      const graph = buildLinearImageToImageGraph(state);
+      const model = state.generation.model;
+      let graph;
+      if (model && model.base_model === 'sdxl') {
+        graph = buildLinearSDXLImageToImageGraph(state);
+      } else {
+        graph = buildLinearImageToImageGraph(state);
+      }
       dispatch(imageToImageGraphBuilt(graph));
       log.debug({ graph: parseify(graph) }, 'Image to Image graph built');

View File

@@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
 import { userInvoked } from 'app/store/actions';
 import { parseify } from 'common/util/serialize';
 import { textToImageGraphBuilt } from 'features/nodes/store/actions';
+import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
 import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
 import { sessionReadyToInvoke } from 'features/system/store/actions';
 import { sessionCreated } from 'services/api/thunks/session';
@@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
     effect: async (action, { getState, dispatch, take }) => {
       const log = logger('session');
       const state = getState();
-      const graph = buildLinearTextToImageGraph(state);
+      const model = state.generation.model;
+      let graph;
+      if (model && model.base_model === 'sdxl') {
+        graph = buildLinearSDXLTextToImageGraph(state);
+      } else {
+        graph = buildLinearTextToImageGraph(state);
+      }
       dispatch(textToImageGraphBuilt(graph));

View File

@@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
 import nodesReducer from 'features/nodes/store/nodesSlice';
 import generationReducer from 'features/parameters/store/generationSlice';
 import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
+import sdxlReducer from 'features/sdxl/store/sdxlSlice';
 import configReducer from 'features/system/store/configSlice';
 import systemReducer from 'features/system/store/systemSlice';
 import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
@@ -47,6 +48,7 @@ const allReducers = {
   imageDeletion: imageDeletionReducer,
   lora: loraReducer,
   modelmanager: modelmanagerReducer,
+  sdxl: sdxlReducer,
   [api.reducerPath]: api.reducer,
 };
@@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
   'canvas',
   'gallery',
   'generation',
+  'sdxl',
   'nodes',
   'postprocessing',
   'system',

View File

@@ -95,7 +95,8 @@ export type AppFeature =
   | 'localization'
   | 'consoleLogging'
   | 'dynamicPrompting'
-  | 'batches';
+  | 'batches'
+  | 'syncModels';
 /**
  * A disable-able Stable Diffusion feature

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
// import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
+import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { modelsApi } from '../../services/api/endpoints/models';
const readinessSelector = createSelector(
@ -24,7 +25,7 @@ const readinessSelector = createSelector(
}
const { isSuccess: mainModelsSuccessfullyLoaded } =
-modelsApi.endpoints.getMainModels.select()(state);
+modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');

View File

@ -2,8 +2,8 @@ import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
+import GenerationModeStatusText from 'features/parameters/components/Parameters/Canvas/GenerationModeStatusText';
import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import roundToHundreth from '../util/roundToHundreth';
import IAICanvasStatusTextCursorPos from './IAICanvasStatusText/IAICanvasStatusTextCursorPos';
@ -110,6 +110,7 @@ const IAICanvasStatusText = () => {
},
}}
>
+<GenerationModeStatusText />
<Box
style={{
color: activeLayerColor,

View File

@ -2,7 +2,15 @@ import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
+import IAIMantineSelect from 'common/components/IAIMantineSelect';
+import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';
+import {
+canvasCopiedToClipboard,
+canvasDownloadedAsImage,
+canvasMerged,
+canvasSavedToGallery,
+} from 'features/canvas/store/actions';
import {
canvasSelector,
isStagingSelector,
@ -21,16 +29,8 @@ import {
} from 'features/canvas/store/canvasTypes';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
+import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { isEqual } from 'lodash-es';
-import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
-import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
-import {
-canvasCopiedToClipboard,
-canvasDownloadedAsImage,
-canvasMerged,
-canvasSavedToGallery,
-} from 'features/canvas/store/actions';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@ -48,7 +48,6 @@ import IAICanvasRedoButton from './IAICanvasRedoButton';
import IAICanvasSettingsButtonPopover from './IAICanvasSettingsButtonPopover';
import IAICanvasToolChooserOptions from './IAICanvasToolChooserOptions';
import IAICanvasUndoButton from './IAICanvasUndoButton';
-import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
export const selector = createSelector(
[systemSelector, canvasSelector, isStagingSelector],
@ -220,7 +219,7 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
-<IAIMantineSearchableSelect
+<IAIMantineSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer}
data={LAYER_NAMES_DICT}

View File

@ -0,0 +1,72 @@
import { useAppSelector } from 'app/store/storeHooks';
import { GenerationMode } from 'features/canvas/store/canvasTypes';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { useEffect, useState } from 'react';
import { useDebounce } from 'react-use';
export const useCanvasGenerationMode = () => {
const layerState = useAppSelector((state) => state.canvas.layerState);
const boundingBoxCoordinates = useAppSelector(
(state) => state.canvas.boundingBoxCoordinates
);
const boundingBoxDimensions = useAppSelector(
(state) => state.canvas.boundingBoxDimensions
);
const isMaskEnabled = useAppSelector((state) => state.canvas.isMaskEnabled);
const shouldPreserveMaskedArea = useAppSelector(
(state) => state.canvas.shouldPreserveMaskedArea
);
const [generationMode, setGenerationMode] = useState<
GenerationMode | undefined
>();
useEffect(() => {
setGenerationMode(undefined);
}, [
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
]);
useDebounce(
async () => {
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea
);
if (!canvasBlobsAndImageData) {
return;
}
const { baseImageData, maskImageData } = canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(
baseImageData,
maskImageData
);
setGenerationMode(generationMode);
},
1000,
[
layerState,
boundingBoxCoordinates,
boundingBoxDimensions,
isMaskEnabled,
shouldPreserveMaskedArea,
]
);
return generationMode;
};
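// Illustrative usage sketch (hypothetical consumer, not part of this commit):
// the hook returns `undefined` while the debounced recomputation is pending,
// so a component such as GenerationModeStatusText can render a fallback:
//
//   const GenerationModeBadge = () => {
//     const mode = useCanvasGenerationMode();
//     return <span>{mode ?? 'computing...'}</span>;
//   };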

View File

@ -168,4 +168,7 @@ export interface CanvasState {
stageDimensions: Dimensions;
stageScale: number;
tool: CanvasTool;
+generationMode?: GenerationMode;
}
+export type GenerationMode = 'txt2img' | 'img2img' | 'inpaint' | 'outpaint';

View File

@ -1,6 +1,10 @@
import { logger } from 'app/logging/logger';
-import { RootState } from 'app/store/store';
+import { Vector2d } from 'konva/lib/types';
-import { isCanvasMaskLine } from '../store/canvasTypes';
+import {
+CanvasLayerState,
+Dimensions,
+isCanvasMaskLine,
+} from '../store/canvasTypes';
import createMaskStage from './createMaskStage';
import { getCanvasBaseLayer, getCanvasStage } from './konvaInstanceProvider';
import { konvaNodeToBlob } from './konvaNodeToBlob';
@ -9,7 +13,13 @@ import { konvaNodeToImageData } from './konvaNodeToImageData';
/**
 * Gets Blob and ImageData objects for the base and mask layers
 */
-export const getCanvasData = async (state: RootState) => {
+export const getCanvasData = async (
+layerState: CanvasLayerState,
+boundingBoxCoordinates: Vector2d,
+boundingBoxDimensions: Dimensions,
+isMaskEnabled: boolean,
+shouldPreserveMaskedArea: boolean
+) => {
const log = logger('canvas');
const canvasBaseLayer = getCanvasBaseLayer();
@ -20,14 +30,6 @@ export const getCanvasData = async (state: RootState) => {
return;
}
-const {
-layerState: { objects },
-boundingBoxCoordinates,
-boundingBoxDimensions,
-isMaskEnabled,
-shouldPreserveMaskedArea,
-} = state.canvas;
const boundingBox = {
...boundingBoxCoordinates,
...boundingBoxDimensions,
@ -58,7 +60,7 @@ export const getCanvasData = async (state: RootState) => {
// For the mask layer, use the normal boundingBox
const maskStage = await createMaskStage(
-isMaskEnabled ? objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
+isMaskEnabled ? layerState.objects.filter(isCanvasMaskLine) : [], // only include mask lines, and only if mask is enabled
boundingBox,
shouldPreserveMaskedArea
);

View File

@ -2,11 +2,12 @@ import {
areAnyPixelsBlack,
getImageDataTransparency,
} from 'common/util/arrayBuffer';
+import { GenerationMode } from '../store/canvasTypes';
export const getCanvasGenerationMode = (
baseImageData: ImageData,
maskImageData: ImageData
-) => {
+): GenerationMode => {
const {
isPartiallyTransparent: baseIsPartiallyTransparent,
isFullyTransparent: baseIsFullyTransparent,
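// The rest of the function is not shown in this hunk; presumably it maps the
// transparency measurements above to a GenerationMode, roughly:
//   base fully transparent              -> 'txt2img'
//   base partially transparent          -> 'outpaint'
//   mask contains drawn (black) pixels  -> 'inpaint'
//   otherwise                           -> 'img2img'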

View File

@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
+import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
+if (type === 'refiner_model' && template.type === 'refiner_model') {
+return (
+<RefinerModelInputFieldComponent
+nodeId={nodeId}
+field={field}
+template={template}
+/>
+);
+}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent

View File

@ -14,8 +14,10 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
+import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
+import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
const ModelInputFieldComponent = (
props: FieldComponentProps<MainModelInputFieldValue, ModelInputFieldTemplate>
@ -24,8 +26,11 @@ const ModelInputFieldComponent = (
const dispatch = useAppDispatch();
const { t } = useTranslation();
+const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
-const { data: mainModels, isLoading } = useGetMainModelsQuery();
+const { data: mainModels, isLoading } = useGetMainModelsQuery(
+NON_REFINER_BASE_MODELS
+);
const data = useMemo(() => {
if (!mainModels) {
@ -103,9 +108,11 @@ const ModelInputFieldComponent = (
disabled={data.length === 0}
onChange={handleChangeModel}
/>
+{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
+)}
</Flex>
);
};

View File

@ -0,0 +1,120 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
RefinerModelInputFieldTemplate,
RefinerModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const RefinerModelInputFieldComponent = (
props: FieldComponentProps<
RefinerModelInputFieldValue,
RefinerModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
return isLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(RefinerModelInputFieldComponent);

View File

@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip',
VaeField: 'vae',
model: 'model',
+refiner_model: 'refiner_model',
vae_model: 'vae_model',
lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model',
description: 'Models are models.',
},
+refiner_model: {
+color: 'teal',
+colorCssVar: getColorTokenCssVariable('teal'),
+title: 'Refiner Model',
+description: 'Models are models.',
+},
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),

View File

@ -70,6 +70,7 @@ export type FieldType =
| 'vae'
| 'control'
| 'model'
+| 'refiner_model'
| 'vae_model'
| 'lora_model'
| 'controlnet_model'
@ -100,6 +101,7 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| MainModelInputFieldValue
+| RefinerModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
@ -128,6 +130,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
+| RefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
value?: MainModelParam;
};
+export type RefinerModelInputFieldValue = FieldValueBase & {
+type: 'refiner_model';
+value?: MainModelParam;
+};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: VaeModelParam;
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model';
};
+export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
+default: string;
+type: 'refiner_model';
+};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';

View File

@ -22,6 +22,7 @@ import {
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
+RefinerModelInputFieldTemplate,
StringInputFieldTemplate,
TypeHints,
UNetInputFieldTemplate,
@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
return template;
};
+const buildRefinerModelInputFieldTemplate = ({
+schemaObject,
+baseField,
+}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
+const template: RefinerModelInputFieldTemplate = {
+...baseField,
+type: 'refiner_model',
+inputRequirement: 'always',
+inputKind: 'direct',
+default: schemaObject.default ?? undefined,
+};
+return template;
+};
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
+if (['refiner_model'].includes(fieldType)) {
+return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
+}
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -76,6 +76,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
+if (template.type === 'refiner_model') {
+fieldValue.value = undefined;
+}
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}

View File

@ -0,0 +1,70 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
} from './constants';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
nodeToAddTo.is_intermediate = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate,
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
});
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: NSFW_CHECKER,
field: 'metadata',
},
});
}
};
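// Resulting wiring: the image output of `nodeIdToAddTo` (the l2i node by
// default) is re-routed into img_nsfw, and the metadata edge is duplicated
// onto the checker whenever a metadata accumulator exists in the graph.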

View File

@ -0,0 +1,186 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_LATENTS_TO_LATENTS,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
} from './constants';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { positivePrompt, negativePrompt } = state.generation;
const {
refinerModel,
refinerAestheticScore,
positiveStylePrompt,
negativeStylePrompt,
refinerSteps,
refinerScheduler,
refinerCFGScale,
refinerStart,
} = state.sdxl;
if (!refinerModel) return;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;
metadataAccumulator.refiner_start = refinerStart;
metadataAccumulator.refiner_steps = refinerSteps;
}
// Unplug SDXL Latents Generation To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
);
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
);
// connect the VAE back to the i2l, which we just removed in the filter
// but only if we are doing l2l
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
graph.edges.push({
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER,
model: refinerModel,
};
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_POSITIVE_CONDITIONING,
style: `${positivePrompt} ${positiveStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
style: `${negativePrompt} ${negativeStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
type: 'l2l_sdxl',
id: SDXL_REFINER_LATENTS_TO_LATENTS,
cfg_scale: refinerCFGScale,
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
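// Step scaling example: refinerSteps = 20 and refinerStart = 0.8 give
// steps = 20 / (1 - 0.8) = 100, so the refiner's slice of the schedule
// (denoising 0.8 -> 1.0) runs the requested 20 steps.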
scheduler: refinerScheduler,
denoising_start: refinerStart,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: baseNodeId,
field: 'latents',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
}
);
};

View File

@ -0,0 +1,95 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as
| ImageNSFWBlurInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
const watermarkerNode: ImageWatermarkInvocation = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate,
};
graph.nodes[WATERMARKER] = watermarkerNode;
// no matter the situation, we want the l2i node to be intermediate
nodeToAddTo.is_intermediate = true;
if (nsfwCheckerNode) {
// if we are using the NSFW checker, we need to "disable" its output by marking it intermediate,
// then connect it to the watermark node
nsfwCheckerNode.is_intermediate = true;
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
} else {
// otherwise we just connect to the watermark node
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
}
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: WATERMARKER,
field: 'metadata',
},
});
}
};
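// Net effect on the image chain: l2i -> img_nsfw -> img_watermark when the
// NSFW checker is present, otherwise l2i -> img_watermark; upstream nodes are
// forced intermediate so only the watermarked image is surfaced.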

View File

@ -10,7 +10,9 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
IMAGE_TO_IMAGE_GRAPH,
@ -23,8 +25,6 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
-NSFW_CHECKER,
-WATERMARKER,
} from './constants';
/**
@ -105,11 +105,6 @@ export const buildCanvasImageToImageGraph = (
is_intermediate: true,
skipped_layers: clipSkip,
},
-[LATENTS_TO_IMAGE]: {
-type: 'l2i',
-id: LATENTS_TO_IMAGE,
-is_intermediate: true,
-},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
id: LATENTS_TO_LATENTS,
@ -128,15 +123,10 @@ export const buildCanvasImageToImageGraph = (
// image_name: initialImage.image_name,
// },
},
-[NSFW_CHECKER]: {
-type: 'img_nsfw',
-id: NSFW_CHECKER,
-is_intermediate: true,
-},
-[WATERMARKER]: {
+[LATENTS_TO_IMAGE]: {
+type: 'l2i',
+id: LATENTS_TO_IMAGE,
is_intermediate: !shouldAutoSave,
-type: 'img_watermark',
-id: WATERMARKER,
},
},
edges: [
@ -180,26 +170,6 @@ export const buildCanvasImageToImageGraph = (
field: 'latents',
},
},
-{
-source: {
-node_id: LATENTS_TO_IMAGE,
-field: 'image',
-},
-destination: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-},
-{
-source: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'image',
-},
-},
{
source: {
node_id: IMAGE_TO_LATENTS,
@ -342,17 +312,6 @@ export const buildCanvasImageToImageGraph = (
init_image: initialImage.image_name,
};
-graph.edges.push({
-source: {
-node_id: METADATA_ACCUMULATOR,
-field: 'metadata',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'metadata',
-},
-});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
@ -365,5 +324,16 @@ export const buildCanvasImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
+// NSFW & watermark - must be last thing added to graph
+if (state.system.shouldUseNSFWChecker) {
+// must add before watermarker!
+addNSFWCheckerToGraph(state, graph);
+}
+if (state.system.shouldUseWatermarker) {
+// must add after nsfw checker!
+addWatermarkerToGraph(state, graph);
+}
return graph;
};

View File

@ -20,6 +20,8 @@ import {
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
/**
 * Builds the Canvas tab's Inpaint graph.
@ -249,5 +251,16 @@ export const buildCanvasInpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
+// NSFW & watermark - must be last thing added to graph
+if (state.system.shouldUseNSFWChecker) {
+// must add before watermarker!
+addNSFWCheckerToGraph(state, graph, INPAINT);
+}
+if (state.system.shouldUseWatermarker) {
+// must add after nsfw checker!
+addWatermarkerToGraph(state, graph, INPAINT);
+}
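// The third argument overrides the default LATENTS_TO_IMAGE anchor, since in
// the inpaint graph the final image comes from the INPAINT node instead.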
return graph;
};

View File

@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
@ -16,8 +18,6 @@ import {
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
-NSFW_CHECKER,
-WATERMARKER,
} from './constants';
/**
@ -109,16 +109,6 @@ export const buildCanvasTextToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
-is_intermediate: true,
-},
-[NSFW_CHECKER]: {
-type: 'img_nsfw',
-id: NSFW_CHECKER,
-is_intermediate: true,
-},
-[WATERMARKER]: {
-type: 'img_watermark',
-id: WATERMARKER,
is_intermediate: !shouldAutoSave,
},
},
@ -193,26 +183,6 @@ export const buildCanvasTextToImageGraph = (
field: 'latents',
},
},
-{
-source: {
-node_id: LATENTS_TO_IMAGE,
-field: 'image',
-},
-destination: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-},
-{
-source: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'image',
-},
-},
{
source: {
node_id: NOISE,
@ -247,17 +217,6 @@ export const buildCanvasTextToImageGraph = (
clip_skip: clipSkip,
};
-graph.edges.push({
-source: {
-node_id: METADATA_ACCUMULATOR,
-field: 'metadata',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'metadata',
-},
-});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
@ -270,5 +229,16 @@ export const buildCanvasTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
+// NSFW & watermark - must be last thing added to graph
+if (state.system.shouldUseNSFWChecker) {
+// must add before watermarker!
+addNSFWCheckerToGraph(state, graph);
+}
+if (state.system.shouldUseWatermarker) {
+// must add after nsfw checker!
+addWatermarkerToGraph(state, graph);
+}
return graph;
};

View File

@ -9,7 +9,9 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
IMAGE_TO_IMAGE_GRAPH,
@ -22,8 +24,6 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
-NSFW_CHECKER,
-WATERMARKER,
} from './constants';
/**
@ -48,6 +48,7 @@ export const buildLinearImageToImageGraph = (
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
+vaePrecision,
} = state.generation;
// TODO: add batch functionality
@ -115,7 +116,7 @@ export const buildLinearImageToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
-is_intermediate: true,
+fp32: vaePrecision === 'fp32' ? true : false,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
@ -131,15 +132,8 @@ export const buildLinearImageToImageGraph = (
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
// image_name: initialImage.image_name,
-},
-[NSFW_CHECKER]: {
-type: 'img_nsfw',
-id: NSFW_CHECKER,
-is_intermediate: true,
-},
-[WATERMARKER]: {
-type: 'img_watermark',
-id: WATERMARKER,
+// },
+fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
@ -193,26 +187,6 @@ export const buildLinearImageToImageGraph = (
field: 'latents',
},
},
-{
-source: {
-node_id: LATENTS_TO_IMAGE,
-field: 'image',
-},
-destination: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-},
-{
-source: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'image',
-},
-},
{
source: {
node_id: IMAGE_TO_LATENTS,
@ -325,42 +299,6 @@ export const buildLinearImageToImageGraph = (
});
}
-// TODO: add batch functionality
-// if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
-// // we are going to connect an iterate up to the init image
-// delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
-// const imageCollection: ImageCollectionInvocation = {
-// id: IMAGE_COLLECTION,
-// type: 'image_collection',
-// images: batchImageNames.map((image_name) => ({ image_name })),
-// };
-// const imageCollectionIterate: IterateInvocation = {
-// id: IMAGE_COLLECTION_ITERATE,
-// type: 'iterate',
-// };
-// graph.nodes[IMAGE_COLLECTION] = imageCollection;
-// graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
-// graph.edges.push({
-// source: { node_id: IMAGE_COLLECTION, field: 'collection' },
-// destination: {
-// node_id: IMAGE_COLLECTION_ITERATE,
-// field: 'collection',
-// },
-// });
-// graph.edges.push({
-// source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
-// destination: {
-// node_id: IMAGE_TO_LATENTS,
-// field: 'image',
-// },
-// });
-// }
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
@ -384,17 +322,6 @@ export const buildLinearImageToImageGraph = (
init_image: initialImage.imageName,
};
-graph.edges.push({
-source: {
-node_id: METADATA_ACCUMULATOR,
-field: 'metadata',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'metadata',
-},
-});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
@ -407,5 +334,16 @@ export const buildLinearImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
+// NSFW & watermark - must be last thing added to graph
+if (state.system.shouldUseNSFWChecker) {
+// must add before watermarker!
+addNSFWCheckerToGraph(state, graph);
+}
+if (state.system.shouldUseWatermarker) {
+// must add after nsfw checker!
+addWatermarkerToGraph(state, graph);
+}
return graph;
};

View File

@ -0,0 +1,382 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import {
ImageResizeInvocation,
ImageToLatentsInvocation,
} from 'services/api/types';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
initialImage,
shouldFitToWidthHeight,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
sdxlImg2ImgDenoisingStrength: strength,
} = state.sdxl;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
if (!initialImage) {
log.error('No initial image found in state');
throw new Error('No initial image found in state');
}
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
[SDXL_LATENTS_TO_LATENTS]: {
type: 'l2l_sdxl',
id: SDXL_LATENTS_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_start: shouldUseSDXLRefiner
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
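// Worked example: strength = 0.7 and refinerStart = 0.8 yield
// denoising_start = min(0.8, 1 - 0.7) = 0.3 and denoising_end = 0.8,
// leaving the 0.8 -> 1.0 tail of the schedule for the refiner that
// addSDXLRefinerToGraph appends below.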
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
// image_name: initialImage.image_name,
// },
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
],
};
// handle `fit`
if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage.imageName,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -0,0 +1,264 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
SDXL_TEXT_TO_IMAGE_GRAPH,
SDXL_TEXT_TO_LATENTS,
} from './constants';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
export const buildLinearSDXLTextToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_TEXT_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
width,
height,
use_cpu,
},
[SDXL_TEXT_TO_LATENTS]: {
type: 't2l_sdxl',
id: SDXL_TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
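// With the refiner enabled, the base model stops denoising at refinerStart
// (e.g. 0.8); addSDXLRefinerToGraph later rewires the latents so the refiner
// finishes the remaining refinerStart -> 1.0 portion of the schedule.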
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
+import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
+import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
@ -16,8 +18,6 @@ import {
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
-NSFW_CHECKER,
-WATERMARKER,
} from './constants';
export const buildLinearTextToImageGraph = (
@ -36,6 +36,7 @@ export const buildLinearTextToImageGraph = (
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
+vaePrecision,
} = state.generation;
const use_cpu = shouldUseNoiseSettings
@ -97,16 +98,7 @@ export const buildLinearTextToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
-is_intermediate: true,
+fp32: vaePrecision === 'fp32' ? true : false,
-},
-[NSFW_CHECKER]: {
-type: 'img_nsfw',
-id: NSFW_CHECKER,
-is_intermediate: true,
-},
-[WATERMARKER]: {
-type: 'img_watermark',
-id: WATERMARKER,
},
},
edges: [
@ -190,26 +182,6 @@ export const buildLinearTextToImageGraph = (
field: 'noise',
},
},
-{
-source: {
-node_id: LATENTS_TO_IMAGE,
-field: 'image',
-},
-destination: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-},
-{
-source: {
-node_id: NSFW_CHECKER,
-field: 'image',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'image',
-},
-},
],
};
@ -234,17 +206,6 @@ export const buildLinearTextToImageGraph = (
clip_skip: clipSkip,
};
-graph.edges.push({
-source: {
-node_id: METADATA_ACCUMULATOR,
-field: 'metadata',
-},
-destination: {
-node_id: WATERMARKER,
-field: 'metadata',
-},
-});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
@ -257,5 +218,16 @@ export const buildLinearTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
+// NSFW & watermark - must be last thing added to graph
+if (state.system.shouldUseNSFWChecker) {
+// must add before watermarker!
+addNSFWCheckerToGraph(state, graph);
+}
+if (state.system.shouldUseWatermarker) {
+// must add after nsfw checker!
+addWatermarkerToGraph(state, graph);
+}
return graph;
};

View File

@@ -25,8 +25,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
 export const REALESRGAN = 'esrgan';
 export const DIVIDE = 'divide';
 export const SCALE = 'scale_image';
+export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
+export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
+export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
+export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
+export const SDXL_REFINER_POSITIVE_CONDITIONING =
+  'sdxl_refiner_positive_conditioning';
+export const SDXL_REFINER_NEGATIVE_CONDITIONING =
+  'sdxl_refiner_negative_conditioning';
+export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';

 // friendly graph ids
 export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
+export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
+export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sdxl_image_to_image_graph';
 export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
 export const INPAINT_GRAPH = 'inpaint_graph';

View File

@@ -13,7 +13,12 @@ import {
   buildOutputFieldTemplates,
 } from './fieldTemplateBuilders';

-const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
+const getReservedFieldNames = (type: string): string[] => {
+  if (type === 'l2i') {
+    return ['id', 'type', 'metadata'];
+  }
+  return ['id', 'type', 'is_intermediate', 'metadata'];
+};

 const invocationDenylist = [
   'Graph',
@@ -21,11 +26,11 @@ const invocationDenylist = [
   'MetadataAccumulatorInvocation',
 ];

-export const parseSchema = (openAPI: OpenAPIV3.Document) => {
-  // filter out non-invocation schemas, plus some tricky invocations for now
+export const parseSchema = (
+  openAPI: OpenAPIV3.Document
+): Record<string, InvocationTemplate> => {
   const filteredSchemas = filter(
-    // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-    openAPI.components!.schemas,
+    openAPI.components?.schemas,
     (schema, key) =>
       key.includes('Invocation') &&
       !key.includes('InvocationOutput') &&
@@ -35,21 +40,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
   const invocations = filteredSchemas.reduce<
     Record<string, InvocationTemplate>
   >((acc, schema) => {
-    // only want SchemaObjects
     if (isInvocationSchemaObject(schema)) {
       const type = schema.properties.type.default;
+      const RESERVED_FIELD_NAMES = getReservedFieldNames(type);
+
       const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
       const typeHints = schema.ui?.type_hints;
       const inputs: Record<string, InputFieldTemplate> = {};
       if (type === 'collect') {
-        const itemProperty = schema.properties[
-          'item'
-        ] as InvocationSchemaObject;
-        // Handle the special Collect node
+        const itemProperty = schema.properties.item as InvocationSchemaObject;
         inputs.item = {
           type: 'item',
           name: 'item',
@@ -60,10 +61,8 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
           default: undefined,
         };
       } else if (type === 'iterate') {
-        const itemProperty = schema.properties[
-          'collection'
-        ] as InvocationSchemaObject;
+        const itemProperty = schema.properties
+          .collection as InvocationSchemaObject;
         inputs.collection = {
           type: 'array',
           name: 'collection',
@@ -74,18 +73,18 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
           inputKind: 'connection',
         };
       } else {
-        // All other nodes
         reduce(
           schema.properties,
           (inputsAccumulator, property, propertyName) => {
             if (
-              // `type` and `id` are not valid inputs/outputs
               !RESERVED_FIELD_NAMES.includes(propertyName) &&
               isSchemaObject(property)
             ) {
-              const field: InputFieldTemplate | undefined =
-                buildInputFieldTemplate(property, propertyName, typeHints);
+              const field = buildInputFieldTemplate(
+                property,
+                propertyName,
+                typeHints
+              );
               if (field) {
                 inputsAccumulator[propertyName] = field;
               }
@@ -97,22 +96,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
       }

       const rawOutput = (schema as InvocationSchemaObject).output;
       let outputs: Record<string, OutputFieldTemplate>;
-      // some special handling is needed for collect, iterate and range nodes
       if (type === 'iterate') {
-        // this is guaranteed to be a SchemaObject
-        // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
-        const iterationOutput = openAPI.components!.schemas![
+        const iterationOutput = openAPI.components?.schemas?.[
           'IterateInvocationOutput'
         ] as OpenAPIV3.SchemaObject;
         outputs = {
           item: {
             name: 'item',
-            title: iterationOutput.title ?? '',
-            description: iterationOutput.description ?? '',
+            title: iterationOutput?.title ?? '',
+            description: iterationOutput?.description ?? '',
             type: 'array',
           },
         };
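The practical effect of swapping the constant for `getReservedFieldNames` is that `is_intermediate` becomes an editable input on `l2i` nodes only, while remaining reserved everywhere else. A minimal sketch of the resulting filter, under that reading:

// Sketch: which schema properties survive as node inputs.
const getReservedFieldNames = (type: string): string[] =>
  type === 'l2i'
    ? ['id', 'type', 'metadata']
    : ['id', 'type', 'is_intermediate', 'metadata'];

const inputNames = (type: string, properties: Record<string, unknown>) =>
  Object.keys(properties).filter(
    (name) => !getReservedFieldNames(type).includes(name)
  );

inputNames('l2i', { id: 1, is_intermediate: true, latents: {} });
// -> ['is_intermediate', 'latents']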

View File

@ -0,0 +1,21 @@
import { Box } from '@chakra-ui/react';
import { useCanvasGenerationMode } from 'features/canvas/hooks/useCanvasGenerationMode';
const GENERATION_MODE_NAME_MAP = {
txt2img: 'Text to Image',
img2img: 'Image to Image',
inpaint: 'Inpaint',
outpaint: 'Inpaint',
};
const GenerationModeStatusText = () => {
const generationMode = useCanvasGenerationMode();
return (
<Box>
Mode: {generationMode ? GENERATION_MODE_NAME_MAP[generationMode] : '...'}
</Box>
);
};
export default GenerationModeStatusText;

View File

@@ -4,6 +4,7 @@ import { memo } from 'react';
 import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
 import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
 import ParamScheduler from './ParamScheduler';
+import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';

 const ParamModelandVAEandScheduler = () => {
   const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
@@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
       <Box w="full">
         <ParamMainModelSelect />
       </Box>
-      <Flex gap={3} w="full">
-        {isVaeEnabled && (
-          <Box w="full">
-            <ParamVAEModelSelect />
-          </Box>
-        )}
-        <Box w="full">
-          <ParamScheduler />
-        </Box>
-      </Flex>
+      <Box w="full">
+        <ParamScheduler />
+      </Box>
+      {isVaeEnabled && (
+        <Flex w="full" gap={3}>
+          <ParamVAEModelSelect />
+          <ParamVAEPrecision />
+        </Flex>
+      )}
     </Flex>
   );
 };

View File

@@ -13,8 +13,11 @@ import { modelSelected } from 'features/parameters/store/actions';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
 import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
+import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
 import { forEach } from 'lodash-es';
+import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
 import { useGetMainModelsQuery } from 'services/api/endpoints/models';
+import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';

 const selector = createSelector(
   stateSelector,
@@ -28,7 +31,12 @@ const ParamMainModelSelect = () => {
   const { model } = useAppSelector(selector);

-  const { data: mainModels, isLoading } = useGetMainModelsQuery();
+  const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
+  const { data: mainModels, isLoading } = useGetMainModelsQuery(
+    NON_REFINER_BASE_MODELS
+  );
+  const activeTabName = useAppSelector(activeTabNameSelector);

   const data = useMemo(() => {
     if (!mainModels) {
@@ -38,7 +46,10 @@ const ParamMainModelSelect = () => {
     const data: SelectItem[] = [];

     forEach(mainModels.entities, (model, id) => {
-      if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
+      if (
+        !model ||
+        (activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
+      ) {
         return;
       }
@@ -50,7 +61,7 @@ const ParamMainModelSelect = () => {
     });

     return data;
-  }, [mainModels]);
+  }, [mainModels, activeTabName]);

   // grab the full model entity from the RTK Query cache
   // TODO: maybe we should just store the full model entity in state?
@@ -86,7 +97,7 @@ const ParamMainModelSelect = () => {
       data={[]}
     />
   ) : (
-    <Flex w="100%" alignItems="center" gap={2}>
+    <Flex w="100%" alignItems="center" gap={3}>
       <IAIMantineSearchableSelect
         tooltip={selectedModel?.description}
         label={t('modelManager.model')}
@@ -98,9 +109,11 @@ const ParamMainModelSelect = () => {
         onChange={handleChangeModel}
         w="100%"
       />
+      {isSyncModelEnabled && (
         <Box mt={7}>
           <SyncModelsButton iconMode />
         </Box>
+      )}
     </Flex>
   );
 };

View File

@@ -32,11 +32,6 @@ export default function ParamSeed() {
       isInvalid={seed < 0 && shouldGenerateVariations}
       onChange={handleChangeSeed}
       value={seed}
-      formControlProps={{
-        display: 'flex',
-        alignItems: 'center',
-        gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
-      }}
     />
   );
 }

View File

@@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';

 const ParamSeedFull = () => {
   return (
-    <Flex sx={{ gap: 4, alignItems: 'center' }}>
+    <Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
       <ParamSeed />
       <ParamSeedShuffle />
       <ParamSeedRandomize />

View File

@ -0,0 +1,46 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { vaePrecision } = generation;
return { vaePrecision };
},
defaultSelectorOptions
);
const DATA = ['fp16', 'fp32'];
const ParamVAEPrecision = () => {
const dispatch = useAppDispatch();
const { vaePrecision } = useAppSelector(selector);
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaePrecisionChanged(v as PrecisionParam));
},
[dispatch]
);
return (
<IAIMantineSelect
label="VAE Precision"
value={vaePrecision}
data={DATA}
onChange={handleChange}
/>
);
};
export default memo(ParamVAEPrecision);
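The component is a thin two-way binding: it reads `vaePrecision` from the generation slice and dispatches `vaePrecisionChanged` on selection. Outside a component tree the same round trip looks like this (a sketch; `store` is assumed to be the app's configured Redux store):

import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';

// Sketch: the state round trip the select performs.
store.dispatch(vaePrecisionChanged('fp16'));
store.getState().generation.vaePrecision; // 'fp16'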

View File

@@ -11,6 +11,7 @@ import {
   MainModelParam,
   NegativePromptParam,
   PositivePromptParam,
+  PrecisionParam,
   SchedulerParam,
   SeedParam,
   StepsParam,
@@ -51,6 +52,7 @@ export interface GenerationState {
   verticalSymmetrySteps: number;
   model: MainModelField | null;
   vae: VaeModelParam | null;
+  vaePrecision: PrecisionParam;
   seamlessXAxis: boolean;
   seamlessYAxis: boolean;
   clipSkip: number;
@@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
   verticalSymmetrySteps: 0,
   model: null,
   vae: null,
+  vaePrecision: 'fp32',
   seamlessXAxis: false,
   seamlessYAxis: false,
   clipSkip: 0,
@@ -241,6 +244,9 @@ export const generationSlice = createSlice({
       // null is a valid VAE!
       state.vae = action.payload;
     },
+    vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
+      state.vaePrecision = action.payload;
+    },
     setClipSkip: (state, action: PayloadAction<number>) => {
       state.clipSkip = action.payload;
     },
@@ -327,6 +333,7 @@ export const {
   shouldUseCpuNoiseChanged,
   setShouldShowAdvancedOptions,
   setAspectRatio,
+  vaePrecisionChanged,
 } = generationSlice.actions;

 export default generationSlice.reducer;

View File

@@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
   val: unknown
 ): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;

+/**
+ * Zod schema for SDXL positive style prompt parameter
+ */
+export const zPositiveStylePromptSDXL = z.string();
+/**
+ * Type alias for SDXL positive style prompt parameter, inferred from its zod schema
+ */
+export type PositiveStylePromptSDXLParam = z.infer<
+  typeof zPositiveStylePromptSDXL
+>;
+/**
+ * Validates/type-guards a value as a SDXL positive style prompt parameter
+ */
+export const isValidSDXLPositiveStylePrompt = (
+  val: unknown
+): val is PositiveStylePromptSDXLParam =>
+  zPositiveStylePromptSDXL.safeParse(val).success;
+
+/**
+ * Zod schema for SDXL negative style prompt parameter
+ */
+export const zNegativeStylePromptSDXL = z.string();
+/**
+ * Type alias for SDXL negative style prompt parameter, inferred from its zod schema
+ */
+export type NegativeStylePromptSDXLParam = z.infer<
+  typeof zNegativeStylePromptSDXL
+>;
+/**
+ * Validates/type-guards a value as a SDXL negative style prompt parameter
+ */
+export const isValidSDXLNegativeStylePrompt = (
+  val: unknown
+): val is NegativeStylePromptSDXLParam =>
+  zNegativeStylePromptSDXL.safeParse(val).success;
+
 /**
  * Zod schema for steps parameter
  */
@@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
 export const isValidStrength = (val: unknown): val is StrengthParam =>
   zStrength.safeParse(val).success;

+/**
+ * Zod schema for a precision parameter
+ */
+export const zPrecision = z.enum(['fp16', 'fp32']);
+/**
+ * Type alias for precision parameter, inferred from its zod schema
+ */
+export type PrecisionParam = z.infer<typeof zPrecision>;
+/**
+ * Validates/type-guards a value as a precision parameter
+ */
+export const isValidPrecision = (val: unknown): val is PrecisionParam =>
+  zPrecision.safeParse(val).success;
+
 // /**
 //  * Zod schema for BaseModelType
 //  */
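Because `zPrecision` is a `z.enum`, the guard narrows an `unknown` value to the `'fp16' | 'fp32'` union rather than to `string`. A quick sketch of the behavior these schemas give you:

import { z } from 'zod';

// Sketch: mirrors the zPrecision schema and guard added above.
const zPrecision = z.enum(['fp16', 'fp32']);
type PrecisionParam = z.infer<typeof zPrecision>; // 'fp16' | 'fp32'

const isValidPrecision = (val: unknown): val is PrecisionParam =>
  zPrecision.safeParse(val).success;

isValidPrecision('fp32'); // true
isValidPrecision('bf16'); // false: not a member of the enum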

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { sdxlImg2ImgDenoisingStrength } = sdxl;
return {
sdxlImg2ImgDenoisingStrength,
};
},
defaultSelectorOptions
);
const ParamSDXLImg2ImgDenoisingStrength = () => {
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
}, [dispatch]);
return (
<IAISlider
label={`${t('parameters.denoisingStrength')}`}
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={sdxlImg2ImgDenoisingStrength}
isInteger={false}
withInput
withSliderMarks
withReset
/>
);
};
export default memo(ParamSDXLImg2ImgDenoisingStrength);

View File

@ -0,0 +1,149 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.negativeStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLNegativeStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativeStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativeStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Negative Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLNegativeStyleConditioning;

View File

@ -0,0 +1,148 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.positiveStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLPositiveStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositiveStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositiveStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Positive Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLPositiveStyleConditioning;

View File

@ -0,0 +1,48 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseSDXLRefiner } = state.sdxl;
const { shouldUseSliders } = state.ui;
return {
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCollapse = () => {
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
return (
<IAICollapse label="Refiner" activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamUseSDXLRefiner />
<ParamSDXLRefinerModelSelect />
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
<ParamSDXLRefinerSteps />
<ParamSDXLRefinerCFGScale />
</Flex>
<ParamSDXLRefinerScheduler />
<ParamSDXLRefinerAestheticScore />
<ParamSDXLRefinerStart />
</Flex>
</IAICollapse>
);
};
export default ParamSDXLRefinerCollapse;

View File

@ -0,0 +1,78 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const { shouldUseSliders } = ui;
const { shouldRandomizeSeed } = generation;
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, activeLabel };
},
defaultSelectorOptions
);
const SDXLImageToImageTabCoreParameters = () => {
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
return (
<IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex
sx={{
flexDirection: 'column',
gap: 3,
}}
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
)}
<ParamSDXLImg2ImgDenoisingStrength />
<ImageToImageFit />
</Flex>
</IAICollapse>
);
};
export default memo(SDXLImageToImageTabCoreParameters);

View File

@ -0,0 +1,28 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
const SDXLImageToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLImageToImageTabParameters;

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, hotkeys }) => {
const { refinerAestheticScore } = sdxl;
const { shift } = hotkeys;
return {
refinerAestheticScore,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerAestheticScore = () => {
const { refinerAestheticScore, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerAestheticScore(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerAestheticScore(6)),
[dispatch]
);
return (
<IAISlider
label="Aesthetic Score"
step={shift ? 0.1 : 0.5}
min={1}
max={10}
onChange={handleChange}
handleReset={handleReset}
value={refinerAestheticScore}
sliderNumberInputProps={{ max: 10 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerAestheticScore);

View File

@ -0,0 +1,75 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui, hotkeys }) => {
const { refinerCFGScale } = sdxl;
const { shouldUseSliders } = ui;
const { shift } = hotkeys;
return {
refinerCFGScale,
shouldUseSliders,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCFGScale = () => {
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerCFGScale(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerCFGScale(7)),
[dispatch]
);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.cfgScale')}
step={shift ? 0.1 : 0.5}
min={1}
max={20}
onChange={handleChange}
handleReset={handleReset}
value={refinerCFGScale}
sliderNumberInputProps={{ max: 200 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.cfgScale')}
step={0.5}
min={1}
max={200}
onChange={handleChange}
value={refinerCFGScale}
isInteger={false}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerCFGScale);

View File

@ -0,0 +1,111 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
(state) => ({ model: state.sdxl.refinerModel }),
defaultSelectorOptions
);
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { model } = useAppSelector(selector);
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${model?.base_model}/main/${model?.model_name}`
] ?? null,
[refinerModels?.entities, model]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(refinerModelChanged(newModel));
},
[dispatch]
);
return isLoading ? (
<IAIMantineSearchableSelect
label="Refiner Model"
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label="Refiner Model"
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
w="100%"
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(ParamSDXLRefinerModelSelect);

View File

@ -0,0 +1,65 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,
({ ui, sdxl }) => {
const { refinerScheduler } = sdxl;
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return {
refinerScheduler,
data,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerScheduler = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { refinerScheduler, data } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(setRefinerScheduler(v as SchedulerParam));
},
[dispatch]
);
return (
<IAIMantineSearchableSelect
w="100%"
label={t('parameters.scheduler')}
value={refinerScheduler}
data={data}
onChange={handleChange}
disabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerScheduler);

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { refinerStart } = sdxl;
return {
refinerStart,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerStart = () => {
const { refinerStart } = useAppSelector(selector);
const dispatch = useAppDispatch();
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerStart(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerStart(0.7)),
[dispatch]
);
return (
<IAISlider
label="Refiner Start"
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerStart}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerStart);
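Refiner Start is a fraction of the denoising schedule rather than a step count: with the 0.7 default the base model handles roughly the first 70% of denoising and the refiner the rest. A sketch of that split under this reading (the actual handoff is computed server-side by the denoise nodes):

// Sketch: mapping a refinerStart fraction onto a step budget.
const splitSteps = (steps: number, refinerStart: number) => {
  const base = Math.round(steps * refinerStart);
  return { base, refiner: steps - base };
};

splitSteps(30, 0.7); // { base: 21, refiner: 9 }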

View File

@ -0,0 +1,72 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui }) => {
const { refinerSteps } = sdxl;
const { shouldUseSliders } = ui;
return {
refinerSteps,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerSteps = () => {
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => {
dispatch(setRefinerSteps(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setRefinerSteps(20));
}, [dispatch]);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.steps')}
min={1}
max={100}
step={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerSteps}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: 500 }}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.steps')}
min={1}
max={500}
step={1}
onChange={handleChange}
value={refinerSteps}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerSteps);

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldUseSDXLRefiner(e.target.checked));
};
return (
<IAISwitch
label="Use Refiner"
isChecked={shouldUseSDXLRefiner}
onChange={handleUseSDXLRefinerChange}
isDisabled={!isRefinerAvailable}
/>
);
}

View File

@ -0,0 +1,27 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLTextToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLTextToImageTabParameters;

View File

@ -0,0 +1,89 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
MainModelParam,
NegativeStylePromptSDXLParam,
PositiveStylePromptSDXLParam,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { MainModelField } from 'services/api/types';
type SDXLInitialState = {
positiveStylePrompt: PositiveStylePromptSDXLParam;
negativeStylePrompt: NegativeStylePromptSDXLParam;
shouldUseSDXLRefiner: boolean;
sdxlImg2ImgDenoisingStrength: number;
refinerModel: MainModelField | null;
refinerSteps: number;
refinerCFGScale: number;
refinerScheduler: SchedulerParam;
refinerAestheticScore: number;
refinerStart: number;
};
const sdxlInitialState: SDXLInitialState = {
positiveStylePrompt: '',
negativeStylePrompt: '',
shouldUseSDXLRefiner: false,
sdxlImg2ImgDenoisingStrength: 0.7,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerAestheticScore: 6,
refinerStart: 0.7,
};
const sdxlSlice = createSlice({
name: 'sdxl',
initialState: sdxlInitialState,
reducers: {
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.positiveStylePrompt = action.payload;
},
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.negativeStylePrompt = action.payload;
},
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
state.shouldUseSDXLRefiner = action.payload;
},
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
state.sdxlImg2ImgDenoisingStrength = action.payload;
},
refinerModelChanged: (
state,
action: PayloadAction<MainModelParam | null>
) => {
state.refinerModel = action.payload;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
},
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
state.refinerCFGScale = action.payload;
},
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
state.refinerScheduler = action.payload;
},
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
state.refinerAestheticScore = action.payload;
},
setRefinerStart: (state, action: PayloadAction<number>) => {
state.refinerStart = action.payload;
},
},
});
export const {
setPositiveStylePromptSDXL,
setNegativeStylePromptSDXL,
setShouldUseSDXLRefiner,
setSDXLImg2ImgDenoisingStrength,
refinerModelChanged,
setRefinerSteps,
setRefinerCFGScale,
setRefinerScheduler,
setRefinerAestheticScore,
setRefinerStart,
} = sdxlSlice.actions;
export default sdxlSlice.reducer;
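Nothing in this slice is coupled to a component; any caller configures the refiner through the exported actions. A sketch of a typical sequence (assumes the reducer is mounted as `sdxl` on the app's configured `store`):

import {
  setShouldUseSDXLRefiner,
  setRefinerSteps,
  setRefinerStart,
} from 'features/sdxl/store/sdxlSlice';

// Sketch: enable the refiner and hand off at 80% of denoising.
store.dispatch(setShouldUseSDXLRefiner(true));
store.dispatch(setRefinerSteps(30));
store.dispatch(setRefinerStart(0.8));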

View File

@@ -26,6 +26,8 @@ import {
   setShouldConfirmOnDelete,
   shouldAntialiasProgressImageChanged,
   shouldLogToConsoleChanged,
+  shouldUseNSFWCheckerChanged,
+  shouldUseWatermarkerChanged,
 } from 'features/system/store/systemSlice';
 import {
   setShouldShowProgressInViewer,
@@ -42,6 +44,7 @@ import {
 } from 'react';
 import { useTranslation } from 'react-i18next';
 import { LogLevelName } from 'roarr';
+import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo';
 import SettingSwitch from './SettingSwitch';
 import SettingsClearIntermediates from './SettingsClearIntermediates';
 import SettingsSchedulers from './SettingsSchedulers';
@@ -57,6 +60,8 @@ const selector = createSelector(
       shouldLogToConsole,
       shouldAntialiasProgressImage,
       isNodesEnabled,
+      shouldUseNSFWChecker,
+      shouldUseWatermarker,
     } = system;

     const {
@@ -78,6 +83,8 @@ const selector = createSelector(
       shouldAntialiasProgressImage,
       shouldShowAdvancedOptions,
       isNodesEnabled,
+      shouldUseNSFWChecker,
+      shouldUseWatermarker,
     };
   },
   {
@@ -120,6 +127,16 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
     }
   }, [shouldShowDeveloperSettings, dispatch]);

+  const { isNSFWCheckerAvailable, isWatermarkerAvailable } =
+    useGetAppConfigQuery(undefined, {
+      selectFromResult: ({ data }) => ({
+        isNSFWCheckerAvailable:
+          data?.nsfw_methods.includes('nsfw_checker') ?? false,
+        isWatermarkerAvailable:
+          data?.watermarking_methods.includes('invisible_watermark') ?? false,
+      }),
+    });
+
   const {
     isOpen: isSettingsModalOpen,
     onOpen: onSettingsModalOpen,
@@ -143,6 +160,8 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
     shouldAntialiasProgressImage,
     shouldShowAdvancedOptions,
     isNodesEnabled,
+    shouldUseNSFWChecker,
+    shouldUseWatermarker,
   } = useAppSelector(selector);

   const handleClickResetWebUI = useCallback(() => {
@@ -221,6 +240,22 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
             <StyledFlex>
               <Heading size="sm">{t('settings.generation')}</Heading>
               <SettingsSchedulers />
+              <SettingSwitch
+                label="Enable NSFW Checker"
+                isDisabled={!isNSFWCheckerAvailable}
+                isChecked={shouldUseNSFWChecker}
+                onChange={(e: ChangeEvent<HTMLInputElement>) =>
+                  dispatch(shouldUseNSFWCheckerChanged(e.target.checked))
+                }
+              />
+              <SettingSwitch
+                label="Enable Invisible Watermark"
+                isDisabled={!isWatermarkerAvailable}
+                isChecked={shouldUseWatermarker}
+                onChange={(e: ChangeEvent<HTMLInputElement>) =>
+                  dispatch(shouldUseWatermarkerChanged(e.target.checked))
+                }
+              />
             </StyledFlex>

             <StyledFlex>
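`selectFromResult` is what keeps the new query cheap: the hook subscribes to the app-config endpoint but returns only the two derived booleans, so the modal re-renders when availability actually flips, not on every cache update. The same pattern in isolation (hypothetical endpoint field, for illustration):

// Sketch: deriving a stable boolean from an RTK Query result.
const { isFeatureAvailable } = useGetAppConfigQuery(undefined, {
  selectFromResult: ({ data }) => ({
    isFeatureAvailable: data?.features.includes('my_feature') ?? false,
  }),
});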

View File

@@ -1,5 +1,5 @@
 import { UseToastOptions } from '@chakra-ui/react';
-import { PayloadAction, createSlice } from '@reduxjs/toolkit';
+import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
 import { InvokeLogLevel } from 'app/logging/logger';
 import { userInvoked } from 'app/store/actions';
 import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
@@ -16,13 +16,16 @@ import {
   appSocketGraphExecutionStateComplete,
   appSocketInvocationComplete,
   appSocketInvocationError,
+  appSocketInvocationRetrievalError,
   appSocketInvocationStarted,
+  appSocketSessionRetrievalError,
   appSocketSubscribed,
   appSocketUnsubscribed,
 } from 'services/events/actions';
 import { ProgressImage } from 'services/events/types';
 import { makeToast } from '../util/makeToast';
 import { LANGUAGES } from './constants';
+import { startCase } from 'lodash-es';

 export type CancelStrategy = 'immediate' | 'scheduled';
@@ -84,6 +87,8 @@ export interface SystemState {
   language: keyof typeof LANGUAGES;
   isUploading: boolean;
   isNodesEnabled: boolean;
+  shouldUseNSFWChecker: boolean;
+  shouldUseWatermarker: boolean;
 }

 export const initialSystemState: SystemState = {
@@ -116,6 +121,8 @@ export const initialSystemState: SystemState = {
   language: 'en',
   isUploading: false,
   isNodesEnabled: false,
+  shouldUseNSFWChecker: true,
+  shouldUseWatermarker: true,
 };

 export const systemSlice = createSlice({
@@ -191,6 +198,12 @@ export const systemSlice = createSlice({
     setIsNodesEnabled(state, action: PayloadAction<boolean>) {
       state.isNodesEnabled = action.payload;
     },
+    shouldUseNSFWCheckerChanged(state, action: PayloadAction<boolean>) {
+      state.shouldUseNSFWChecker = action.payload;
+    },
+    shouldUseWatermarkerChanged(state, action: PayloadAction<boolean>) {
+      state.shouldUseWatermarker = action.payload;
+    },
   },
   extraReducers(builder) {
     /**
@@ -288,25 +301,6 @@ export const systemSlice = createSlice({
       }
     });

-    /**
-     * Invocation Error
-     */
-    builder.addCase(appSocketInvocationError, (state) => {
-      state.isProcessing = false;
-      state.isCancelable = true;
-      // state.currentIteration = 0;
-      // state.totalIterations = 0;
-      state.currentStatusHasSteps = false;
-      state.currentStep = 0;
-      state.totalSteps = 0;
-      state.statusTranslationKey = 'common.statusError';
-      state.progressImage = null;
-
-      state.toastQueue.push(
-        makeToast({ title: t('toast.serverError'), status: 'error' })
-      );
-    });
-
     /**
      * Graph Execution State Complete
      */
@@ -362,7 +356,7 @@ export const systemSlice = createSlice({
      * Session Invoked - REJECTED
      * Session Created - REJECTED
      */
-    builder.addMatcher(isAnySessionRejected, (state) => {
+    builder.addMatcher(isAnySessionRejected, (state, action) => {
       state.isProcessing = false;
       state.isCancelable = false;
       state.isCancelScheduled = false;
@@ -372,7 +366,35 @@ export const systemSlice = createSlice({
       state.progressImage = null;

       state.toastQueue.push(
-        makeToast({ title: t('toast.serverError'), status: 'error' })
+        makeToast({
+          title: t('toast.serverError'),
+          status: 'error',
+          description:
+            action.payload?.status === 422 ? 'Validation Error' : undefined,
+        })
       );
     });
+
+    /**
+     * Any server error
+     */
+    builder.addMatcher(isAnyServerError, (state, action) => {
+      state.isProcessing = false;
+      state.isCancelable = true;
+      // state.currentIteration = 0;
+      // state.totalIterations = 0;
+      state.currentStatusHasSteps = false;
+      state.currentStep = 0;
+      state.totalSteps = 0;
+      state.statusTranslationKey = 'common.statusError';
+      state.progressImage = null;
+
+      state.toastQueue.push(
+        makeToast({
+          title: t('toast.serverError'),
+          status: 'error',
+          description: startCase(action.payload.data.error_type),
+        })
+      );
+    });
   },
@@ -397,6 +419,14 @@ export const {
   languageChanged,
   progressImageSet,
   setIsNodesEnabled,
+  shouldUseNSFWCheckerChanged,
+  shouldUseWatermarkerChanged,
 } = systemSlice.actions;

 export default systemSlice.reducer;
+
+const isAnyServerError = isAnyOf(
+  appSocketInvocationError,
+  appSocketSessionRetrievalError,
+  appSocketInvocationRetrievalError
+);
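`isAnyOf` composes several action creators into one type-guard predicate, which is why a single `addMatcher` case can replace the dedicated `appSocketInvocationError` handler that this diff deletes. The mechanism in isolation:

import { createAction, isAnyOf } from '@reduxjs/toolkit';

// Sketch: one predicate matching any of several error actions.
const errorA = createAction<{ message: string }>('socket/errorA');
const errorB = createAction<{ message: string }>('socket/errorB');

const isAnyError = isAnyOf(errorA, errorB);
isAnyError(errorA({ message: 'boom' })); // true, and narrows the action type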

View File

@@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
 import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
 import { configSelector } from 'features/system/store/configSelectors';
-import { InvokeTabName } from 'features/ui/store/tabMap';
+import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
 import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
 import { ResourceKey } from 'i18next';
 import { isEqual } from 'lodash-es';
@@ -172,13 +172,22 @@ const InvokeTabs = () => {
   const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
     useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
 
+  const handleTabChange = useCallback(
+    (index: number) => {
+      const activeTabName = tabMap[index];
+      if (!activeTabName) {
+        return;
+      }
+      dispatch(setActiveTab(activeTabName));
+    },
+    [dispatch]
+  );
+
   return (
     <Tabs
       defaultIndex={activeTab}
       index={activeTab}
-      onChange={(index: number) => {
-        dispatch(setActiveTab(index));
-      }}
+      onChange={handleTabChange}
       sx={{
         flexGrow: 1,
         gap: 4,
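
The extracted `handleTabChange` callback guards against out-of-range indexes by translating the numeric tab index into a tab name before dispatching, which also keeps persisted UI state meaningful if the tab order changes. A sketch of the lookup, assuming an ordered `tabMap` tuple (the tab names below are illustrative, not the app's actual list):

```ts
// Illustrative tab order; the real list lives in features/ui/store/tabMap.
const tabMap = ['txt2img', 'img2img', 'unifiedCanvas', 'nodes'] as const;
type InvokeTabName = (typeof tabMap)[number];

const toTabName = (index: number): InvokeTabName | undefined =>
  // An out-of-range index yields undefined, which the handler ignores
  // instead of dispatching an invalid tab name.
  tabMap[index];
```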

View File

@@ -1,7 +1,9 @@
 import { Box, Flex } from '@chakra-ui/react';
-import { useAppDispatch } from 'app/store/storeHooks';
+import { RootState } from 'app/store/store';
+import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
 import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
+import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
 import { memo, useCallback, useRef } from 'react';
 import {
   ImperativePanelGroupHandle,
@@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
 const ImageToImageTab = () => {
   const dispatch = useAppDispatch();
   const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
+  const model = useAppSelector((state: RootState) => state.generation.model);
 
   const handleDoubleClickHandle = useCallback(() => {
     if (!panelGroupRef.current) {
@@ -28,7 +31,11 @@ const ImageToImageTab = () => {
   return (
     <Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
       <ParametersPinnedWrapper>
+        {model && model.base_model === 'sdxl' ? (
+          <SDXLImageToImageTabParameters />
+        ) : (
           <ImageToImageTabParameters />
+        )}
       </ParametersPinnedWrapper>
       <Box sx={{ w: 'full', h: 'full' }}>
         <PanelGroup
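
Both the image-to-image and text-to-image tabs now branch on the selected model's `base_model` to pick the SDXL-specific parameter panel. A stripped-down sketch of the branch; the panel components here are placeholders, not the real `SDXLImageToImageTabParameters` and `ImageToImageTabParameters`:

```tsx
import { memo } from 'react';

// Placeholder panels standing in for the SDXL and standard parameter panels.
const SdxlPanel = () => <div>SDXL parameters</div>;
const StandardPanel = () => <div>Standard parameters</div>;

type Props = { model: { base_model: string } | null };

// Render the SDXL panel only when the selected main model is an SDXL base;
// fall back to the standard panel otherwise.
const ParametersSwitch = ({ model }: Props) =>
  model && model.base_model === 'sdxl' ? <SdxlPanel /> : <StandardPanel />;

export default memo(ParametersSwitch);
```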

View File

@@ -16,6 +16,7 @@ import {
   useImportMainModelsMutation,
 } from 'services/api/endpoints/models';
 import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
+import { ALL_BASE_MODELS } from 'services/api/constants';
 
 export default function FoundModelsList() {
   const searchFolder = useAppSelector(
@@ -24,7 +25,7 @@ export default function FoundModelsList() {
   const [nameFilter, setNameFilter] = useState<string>('');
 
   // Get paths of models that are already installed
-  const { data: installedModels } = useGetMainModelsQuery();
+  const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
 
   // Get all model paths from a given directory
   const { foundModels, alreadyInstalled, filteredModels } =
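
Several model-manager components in this diff switch from `useGetMainModelsQuery()` to `useGetMainModelsQuery(ALL_BASE_MODELS)`. In RTK Query the query argument is serialized into the cache key, so passing one shared constant keeps every call site on a single cache entry instead of splitting it between `undefined` and per-filter variants. A sketch, assuming the constant enumerates the supported base-model types (the exact members are an assumption; the real constant is exported from services/api/constants):

```ts
// Assumed shape of the shared constant.
export const ALL_BASE_MODELS = ['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'] as const;

// Every component that calls the hook with this constant shares one
// cached request:
//   const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
```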

View File

@@ -1,5 +1,4 @@
 import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
-import { makeToast } from 'features/system/util/makeToast';
 import { useAppDispatch } from 'app/store/storeHooks';
 import IAIButton from 'common/components/IAIButton';
 import IAIInput from 'common/components/IAIInput';
@@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
 import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
 import IAISlider from 'common/components/IAISlider';
 import { addToast } from 'features/system/store/systemSlice';
+import { makeToast } from 'features/system/util/makeToast';
 import { pickBy } from 'lodash-es';
 import { useMemo, useState } from 'react';
 import { useTranslation } from 'react-i18next';
+import { ALL_BASE_MODELS } from 'services/api/constants';
 import {
   useGetMainModelsQuery,
   useMergeMainModelsMutation,
@@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
 
-  const { data } = useGetMainModelsQuery();
+  const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
 
   const [mergeModels, { isLoading }] = useMergeMainModelsMutation();

View File

@@ -8,10 +8,11 @@ import {
 import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
 import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
 import ModelList from './ModelManagerPanel/ModelList';
+import { ALL_BASE_MODELS } from 'services/api/constants';
 
 export default function ModelManagerPanel() {
   const [selectedModelId, setSelectedModelId] = useState<string>();
-  const { model } = useGetMainModelsQuery(undefined, {
+  const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
       model: selectedModelId ? data?.entities[selectedModelId] : undefined,
     }),
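
Here, and again in `ModelList` below, the query argument is paired with RTK Query's `selectFromResult` option, which narrows what the component subscribes to: the component re-renders only when the derived value changes, not on every cache update. A self-contained sketch with a hypothetical endpoint standing in for `useGetMainModelsQuery`:

```ts
import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

type Model = { id: string; name: string };

// Hypothetical endpoint for illustration only.
const api = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1' }),
  endpoints: (build) => ({
    getModels: build.query<Record<string, Model>, readonly string[]>({
      query: (baseModels) => ({
        url: 'models',
        params: { base_models: baseModels.join(',') },
      }),
    }),
  }),
});

// selectFromResult derives a narrow slice of the cached result, so a
// component using `model` ignores changes to unrelated entries.
export const useOneModel = (id: string | undefined) =>
  api.useGetModelsQuery(['sd-1', 'sd-2'], {
    selectFromResult: ({ data }) => ({
      model: id && data ? data[id] : undefined,
    }),
  });
```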

View File

@@ -11,6 +11,7 @@ import {
   useGetMainModelsQuery,
 } from 'services/api/endpoints/models';
 import ModelListItem from './ModelListItem';
+import { ALL_BASE_MODELS } from 'services/api/constants';
 
 type ModelListProps = {
   selectedModelId: string | undefined;
@@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
   const [modelFormatFilter, setModelFormatFilter] =
     useState<ModelFormat>('images');
 
-  const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
+  const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
       filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
     }),
   });
 
-  const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
+  const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
     selectFromResult: ({ data }) => ({
       filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
     }),

View File

@@ -1,14 +1,22 @@
 import { Flex } from '@chakra-ui/react';
+import { RootState } from 'app/store/store';
+import { useAppSelector } from 'app/store/storeHooks';
+import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
 import { memo } from 'react';
 import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
 import TextToImageTabMain from './TextToImageTabMain';
 import TextToImageTabParameters from './TextToImageTabParameters';
 
 const TextToImageTab = () => {
+  const model = useAppSelector((state: RootState) => state.generation.model);
   return (
     <Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
       <ParametersPinnedWrapper>
+        {model && model.base_model === 'sdxl' ? (
+          <TextToImageSDXLTabParameters />
+        ) : (
           <TextToImageTabParameters />
+        )}
       </ParametersPinnedWrapper>
       <TextToImageTabMain />
     </Flex>

Some files were not shown because too many files have changed in this diff.