Merge branch 'main' into patch

This commit is contained in:
Lincoln Stein 2023-07-26 10:19:02 -04:00 committed by GitHub
commit c7f883d22a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
128 changed files with 5829 additions and 2668 deletions

View File

@ -1,11 +1,11 @@
name: Close inactive issues
on:
schedule:
- cron: "00 6 * * *"
- cron: "00 4 * * *"
env:
DAYS_BEFORE_ISSUE_STALE: 14
DAYS_BEFORE_ISSUE_CLOSE: 28
DAYS_BEFORE_ISSUE_STALE: 30
DAYS_BEFORE_ISSUE_CLOSE: 14
jobs:
close-issues:
@ -14,7 +14,7 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v5
- uses: actions/stale@v8
with:
days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
@ -23,5 +23,6 @@ jobs:
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
days-before-pr-stale: -1
days-before-pr-close: -1
exempt-issue-labels: "Active Issue"
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 500

View File

@ -136,19 +136,16 @@ command-line options by giving the `--help` argument:
```
(.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
[--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--invisible_watermark | --no-invisible_watermark] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
[--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
[--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
[--log_level {debug,info,warning,error,critical}] [--version | --no-version]
```
## The Configuration Settings
@ -179,6 +176,7 @@ These configuration settings allow you to enable and disable various InvokeAI fe
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `invisible_watermark` | `true` | Write an invisible watermark 'InvokeAI' into generated images for use by AI image detectors |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |

View File

@ -61,11 +61,13 @@ A noise scheduler (eg. DPM++ 2M Karras) schedules the subtraction of noise from
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
| ImageLerp | Linear interpolation of all pixels of an image |
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
| ImageNSFWBlurInvocation | Detects and blurs images that may contain sexually explicit content |
| ImagePaste | Pastes an image into another image |
| ImageProcessor | Base class for invocations that reprocess images for ControlNet |
| ImageResize | Resizes an image to specific dimensions |
| ImageScale | Scales an image by a factor |
| ImageToLatents | Scales latents by a given factor |
| ImageWatermarkInvocation | Adds an invisible watermark to images |
| InfillColor | Infills transparent areas of an image with a solid color |
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
| InfillTile | Infills transparent areas of an image with tiles of the image |

View File

@ -16,21 +16,24 @@ Output Example:
---
## **Seamless Tiling**
## **Invisible Watermark**
The seamless tiling mode causes generated images to seamlessly tile
with itself creating repetitive wallpaper-like patterns. To use it,
activate the Seamless Tiling option in the Web GUI and then select
whether to tile on the X (horizontal) and/or Y (vertical) axes. Tiling
will then be active for the next set of generations.
In keeping with the principles for responsible AI generation, and to
help AI researchers avoid synthetic images contaminating their
training sets, InvokeAI adds an invisible watermark to each of the
final images it generates. The watermark consists of the text
"InvokeAI" and can be viewed using the
[invisible-watermarks](https://github.com/ShieldMnt/invisible-watermark)
tool.
A nice prompt to test seamless tiling with is:
Watermarking is controlled using the `invisible-watermark` setting in
`invokeai.yaml`. To turn it off, add the following line under the `Features`
category.
```
pond garden with lotus by claude monet"
invisible_watermark: false
```
---
## **Weighted Prompts**
@ -39,34 +42,10 @@ priority to them, by adding `:<percent>` to the end of the section you wish to u
example consider this prompt:
```bash
tabby cat:0.25 white duck:0.75 hybrid
(tabby cat):0.25 (white duck):0.75 hybrid
```
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
combination of integers and floating point numbers, and they do not need to add up to 1.
## **Thresholding and Perlin Noise Initialization Options**
Under the Noise section of the Web UI, you will find two options named
Perlin Noise and Noise Threshold. [Perlin
noise](https://en.wikipedia.org/wiki/Perlin_noise) is a type of
structured noise used to simulate terrain and other natural
textures. The slider controls the percentage of perlin noise that will
be mixed into the image at the beginning of generation. Adding a little
perlin noise to a generation will alter the image substantially.
The noise threshold limits the range of the latent values during
sampling and helps combat the oversharpening seem with higher CFG
scale values.
For better intuition into what these options do in practice:
![here is a graphic demonstrating them both](../assets/truncation_comparison.jpg)
In generating this graphic, perlin noise at initialization was
programmatically varied going across on the diagram by values 0.0,
0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0; and the threshold was varied
going down from 0, 1, 2, 3, 4, 5, 10, 20, 100. The other options are
fixed using the prompt "a portrait of a beautiful young lady" a CFG of
20, 100 steps, and a seed of 1950357039.

View File

@ -1,9 +1,15 @@
import typing
from enum import Enum
from fastapi import Body
from fastapi.routing import APIRouter
from pathlib import Path
from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.version import __version__
from ..dependencies import ApiDependencies
@ -16,6 +22,10 @@ class LogLevel(int, Enum):
Warning = logging.WARNING
Error = logging.ERROR
Critical = logging.CRITICAL
class Upscaler(BaseModel):
upscaling_method: str = Field(description="Name of upscaling method")
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
app_router = APIRouter(prefix="/v1/app", tags=["app"])
@ -30,6 +40,9 @@ class AppConfig(BaseModel):
"""App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@app_router.get(
@ -46,7 +59,30 @@ async def get_config() -> AppConfig:
infill_methods = ['tile']
if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(
upscaling_method = 'esrgan',
upscaling_models = upscaling_models
)
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append('nsfw_checker')
watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append('invisible_watermark')
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=[upscaler],
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)
@app_router.get(
"/logging",

View File

@ -203,7 +203,10 @@ def invoke_api():
return find_port(port=port + 1)
else:
return port
from invokeai.backend.install.check_root import check_invokeai_root
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")

View File

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 33 KiB

View File

@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight)
del lora_info
return
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
**clip_field.tokenizer.dict(), context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
**clip_field.text_encoder.dict(), context=context,
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight)
del lora_info
return
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
**clip_field.tokenizer.dict(), context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
**clip_field.text_encoder.dict(), context=context,
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
**lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight)
del lora_info
return
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:

View File

@ -20,7 +20,7 @@ from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .image import ImageOutput, PILInvocationConfig
from ..models.image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################

View File

@ -4,61 +4,21 @@ from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from pydantic import Field
from pathlib import Path
from typing import Union
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from invokeai.app.invocations.metadata import CoreMetadata
from ..models.image import (
ImageCategory, ImageField, ResourceOrigin,
PILInvocationConfig, ImageOutput, MaskOutput,
)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
@ -397,7 +357,6 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
@ -602,7 +561,6 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
@ -650,3 +608,97 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
width=image_dto.width,
height=image_dto.height,
)
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Add blur to NSFW-flagged images"""
# fmt: off
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Blur NSFW Images",
"tags": ["image", "nsfw", "checker"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
logger = context.services.logger
logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution,(0,0),caution)
image = blurry_image
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
def _get_caution_img(self)->Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png')
return caution.resize((caution.width // 2, caution.height //2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
""" Add an invisible watermark to an image """
# fmt: off
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Add Invisible Watermark",
"tags": ["image", "watermark", "invisible"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
new_image = InvisibleWatermark.add_watermark(image, self.text)
image_dto = context.services.images.create(
image=new_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -501,7 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
description="Decode latents by overlapping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")

View File

@ -2,16 +2,18 @@ from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput, InvocationConfig,
InvocationContext)
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
VAEModelField)
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
@ -19,7 +21,9 @@ class LoRAMetadataField(BaseModel):
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",)
generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
@ -29,10 +33,20 @@ class CoreMetadata(BaseModel):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
controlnets: list[ControlField] = Field(
description="The ControlNets used for inference"
)
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# Latents-to-Latents
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
@ -40,9 +54,34 @@ class CoreMetadata(BaseModel):
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
# SDXL
positive_style_prompt: Union[str, None] = Field(
default=None, description="The positive style prompt parameter"
)
negative_style_prompt: Union[str, None] = Field(
default=None, description="The negative style prompt parameter"
)
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(
default=None, description="The SDXL Refiner model used"
)
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(
default=None, description="The number of steps used for the refiner"
)
refiner_scheduler: Union[str, None] = Field(
default=None, description="The scheduler used for the refiner"
)
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(
default=None, description="The start value used for refiner denoising"
)
@ -71,7 +110,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",)
generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
@ -81,9 +122,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",)
clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
controlnets: list[ControlField] = Field(
description="The ControlNets used for inference"
)
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
@ -97,36 +142,44 @@ class MetadataAccumulatorInvocation(BaseInvocation):
description="The VAE used for decoding, if the main model's default was not used",
)
# SDXL
positive_style_prompt: Union[str, None] = Field(
default=None, description="The positive style prompt parameter"
)
negative_style_prompt: Union[str, None] = Field(
default=None, description="The negative style prompt parameter"
)
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(
default=None, description="The SDXL Refiner model used"
)
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(
default=None, description="The number of steps used for the refiner"
)
refiner_scheduler: Union[str, None] = Field(
default=None, description="The scheduler used for the refiner"
)
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(
default=None, description="The start value used for refiner denoising"
)
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"]
"tags": ["image", "metadata", "generation"],
},
}
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput(
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))

View File

@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"},
"type_hints": {"model": "refiner_model"},
},
}
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
**self.unet.unet.dict(), context=context
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None:
if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
**self.unet.unet.dict(), context=context,
)
do_classifier_free_guidance = True
cross_attention_kwargs = None

View File

@ -1,9 +1,80 @@
from enum import Enum
from typing import Optional, Tuple
from typing import Optional, Tuple, Literal
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image).
@ -63,28 +134,3 @@ class InvalidImageCategoryException(ValueError):
super().__init__(message)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -28,7 +28,6 @@ InvokeAI:
always_use_cpu: false
free_gpu_mem: false
Features:
nsfw_checker: true
restore: true
esrgan: true
patchmatch: true
@ -92,18 +91,18 @@ Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
# get global configuration and print its cache size
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.nsfw_checker)
print(conf.max_cache_size)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its nsfw_checker value
# get global configuration and print its cache size value
conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker)
print(conf.max_cache_size)
Computed properties:
@ -277,7 +276,7 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded_from_yaml(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root']
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root', 'nsfw_checker']
class Config:
env_file_encoding = 'utf-8'
@ -364,7 +363,6 @@ setting environment variables INVOKEAI_<setting>.
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
@ -374,6 +372,7 @@ setting environment variables INVOKEAI_<setting>.
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@ -525,6 +524,16 @@ setting environment variables INVOKEAI_<setting>.
"""Return true if patchmatch true"""
return self.patchmatch
@property
def nsfw_checker(self)->bool:
""" NSFW node is always active and disabled from Web UIe"""
return True
@property
def invisible_watermark(self)->bool:
""" invisible watermark node is always active and disabled from Web UIe"""
return True
@staticmethod
def find_root()->Path:
'''

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
@ -24,6 +25,7 @@ def create_text_to_image() -> LibraryGraph:
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
'8': ImageNSFWBlurInvocation(id='8'),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
@ -33,6 +35,7 @@ def create_text_to_image() -> LibraryGraph:
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
]
),
exposed_inputs=[
@ -43,7 +46,7 @@ def create_text_to_image() -> LibraryGraph:
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[
ExposedNodeOutput(node_path='7', field='image', alias='image')
ExposedNodeOutput(node_path='8', field='image', alias='image')
])

View File

@ -216,16 +216,13 @@ class ImageService(ImageServiceABC):
metadata=metadata,
session_id=session_id,
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
self._services.image_files.save(
image_name=image_name, image=image, metadata=metadata, graph=graph
)
image_dto = self.get_dto(image_name)
return image_dto
@ -236,7 +233,7 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e
def update(

View File

@ -12,4 +12,4 @@ from .model_management import (
ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo
)
from .safety_checker import SafetyChecker
from .model_management.models import SilenceWarnings

View File

@ -28,7 +28,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
@ -52,7 +51,6 @@ class InvokeAIGeneratorBasicParams:
v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: Optional[SafetyChecker]=None
@dataclass
class InvokeAIGeneratorOutput:
@ -240,7 +238,6 @@ class Generator:
self.seed = None
self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
@ -277,12 +274,10 @@ class Generator:
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False,
**kwargs,
):
scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem
attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append(
@ -329,9 +324,6 @@ class Generator:
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
results.append([image, seed, attention_maps_images])
if image_callback is not None:

View File

@ -0,0 +1,34 @@
"""
This module defines a singleton object, "invisible_watermark" that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, that allows the watermarking to be supressed.
"""
import numpy as np
import cv2
from PIL import Image
from imwatermark import WatermarkEncoder
from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
class InvisibleWatermark:
"""
Wrapper around InvisibleWatermark module.
"""
@classmethod
def invisible_watermark_available(self) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text:str) -> Image:
if not self.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
bgr_encoded = encoder.encode(bgr, 'dwtDct')
return Image.fromarray(
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
).convert("RGBA")

View File

@ -0,0 +1,63 @@
"""
This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from PIL import Image
from invokeai.backend import SilenceWarnings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
safety_checker = None
feature_extractor = None
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
return
if config.nsfw_checker:
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
config.models_path / CHECKER_PATH
)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
config.models_path / CHECKER_PATH)
logger.info('NSFW checker initialized')
except Exception as e:
logger.warning(f'Could not load NSFW checker: {str(e)}')
else:
logger.info('NSFW checker loading disabled')
self.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

View File

@ -0,0 +1,31 @@
"""
Check that the invokeai_root is correctly configured and exit if not.
"""
import sys
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
def check_invokeai_root(config: InvokeAIAppConfig):
try:
assert config.model_conf_path.exists()
assert config.db_path.exists()
assert config.models_path.exists()
for model in [
'CLIP-ViT-bigG-14-laion2B-39B-b160k',
'bert-base-uncased',
'clip-vit-large-patch14',
'sd-vae-ft-mse',
'stable-diffusion-2-clip',
'stable-diffusion-safety-checker']:
assert (config.models_path / f'core/convert/{model}').exists()
except:
print()
print('== STARTUP ABORTED ==')
print('** One or more necessary files is missing from your InvokeAI root directory **')
print('** Please rerun the configuration script to fix this problem. **')
print('** From the launcher, selection option [7]. **')
print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **')
input('Press any key to continue...')
sys.exit(0)

View File

@ -32,6 +32,7 @@ from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import (
CLIPTextModel,
CLIPTextConfig,
CLIPTokenizer,
AutoFeatureExtractor,
BertTokenizerFast,
@ -55,6 +56,7 @@ from invokeai.frontend.install.widgets import (
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
hf_download_from_pretrained,
hf_download_with_resume,
InstallSelections,
ModelInstall,
)
@ -204,6 +206,15 @@ def download_conversion_models():
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# sd-xl - tokenizer_2
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
_, model_name = repo_id.split('/')
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs)
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
# VAE
logger.info('Downloading stable diffusion VAE')
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
@ -287,47 +298,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
color="CONTROL",
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== BASIC OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Select an output directory for images:",
editable=False,
color="CONTROL",
)
self.outdir = self.add_widget_intelligent(
npyscreen.TitleFilename,
name="(<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.FixedText,
value="Activate the NSFW checker to blur images showing potential sexual imagery:",
editable=False,
color="CONTROL",
)
self.nsfw_checker = self.add_widget_intelligent(
npyscreen.Checkbox,
name="NSFW checker",
value=old_opts.nsfw_checker,
relx=5,
scroll_exit=True,
)
self.nextrely += 1
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
for line in textwrap.wrap(label,width=window_width-6):
@ -347,15 +317,6 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
scroll_exit=True,
)
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="== ADVANCED OPTIONS ==",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="GPU Management",
@ -415,6 +376,18 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
editable=False,
color="CONTROL",
)
self.outdir = self.add_widget_intelligent(
FileBox,
name="Output directory for images (<tab> autocompletes, ctrl-N advances):",
value=str(default_output_dir()),
select_dir=True,
must_exist=False,
use_two_lines=False,
labelColor="GOOD",
begin_entry_at=40,
max_height=3,
scroll_exit=True,
)
self.autoimport_dirs = {}
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
FileBox,
@ -506,7 +479,6 @@ https://huggingface.co/spaces/CompVis/stable-diffusion-license
for attr in [
"outdir",
"nsfw_checker",
"free_gpu_mem",
"max_cache_size",
"xformers_enabled",
@ -542,7 +514,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
"MAIN",
editOptsForm,
name="InvokeAI Startup Options",
cycle_widgets=True,
cycle_widgets=False,
)
if not (self.program_opts.skip_sd_weights or self.program_opts.default_only):
self.model_select = self.addForm(
@ -550,7 +522,7 @@ class EditOptApplication(npyscreen.NPSAppManaged):
addModelsForm,
name="Install Stable Diffusion Models",
multipage=True,
cycle_widgets=True,
cycle_widgets=False,
)
def new_opts(self):
@ -564,8 +536,6 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config()
if not init_file.exists():
opts.nsfw_checker = True
return opts
def default_user_selections(program_opts: Namespace) -> InstallSelections:
@ -689,7 +659,6 @@ def migrate_init_file(legacy_format:Path):
# a few places where the field names have changed and we have to
# manually add in the new names/values
new.nsfw_checker = old.safety_checker
new.xformers_enabled = old.xformers
new.conf_path = old.conf
new.root = legacy_format.parent.resolve()

View File

@ -58,7 +58,15 @@ LEGACY_CONFIGS = {
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
}
}
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: 'sd_xl_base.yaml',
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: 'sd_xl_refiner.yaml',
},
}
@dataclass
@ -330,6 +338,7 @@ class ModelInstall(object):
description = str(description),
model_format = info.format,
)
legacy_conf = None
if info.model_type == ModelType.Main:
attributes.update(dict(variant = info.variant_type,))
if info.format=="checkpoint":
@ -344,11 +353,17 @@ class ModelInstall(object):
except KeyError:
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
attributes.update(
dict(
config = str(legacy_conf)
)
if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
possible_conf = path.with_suffix('.yaml')
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
if legacy_conf:
attributes.update(
dict(
config = str(legacy_conf)
)
)
return attributes
def relative_to_root(self, path: Path)->Path:

File diff suppressed because it is too large Load Diff

View File

@ -673,6 +673,7 @@ class ModelManager(object):
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,
@ -840,7 +841,7 @@ class ModelManager(object):
Returns the preamble for the config file.
"""
return textwrap.dedent(
"""\
"""
# This file describes the alternative machine learning models
# available to InvokeAI script.
#

View File

@ -253,10 +253,13 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()

View File

@ -1,7 +1,8 @@
import os
import torch
from enum import Enum
from typing import Optional
from pathlib import Path
from typing import Optional, Literal
from .base import (
ModelBase,
ModelConfigBase,
@ -15,6 +16,7 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ControlNetModelFormat(str, Enum):
Checkpoint = "checkpoint"
@ -24,8 +26,12 @@ class ControlNetModel(ModelBase):
#model_class: Type
#model_size: int
class Config(ModelConfigBase):
model_format: ControlNetModelFormat
class DiffusersConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Diffusers]
class CheckpointConfig(ModelConfigBase):
model_format: Literal[ControlNetModelFormat.Checkpoint]
config: str
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.ControlNet
@ -99,13 +105,51 @@ class ControlNetModel(ModelBase):
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint:
return _convert_controlnet_ckpt_and_cache(
model_path = model_path,
model_config = config.config,
output_path = output_path,
base_model = base_model,
)
else:
return model_path
@classmethod
def _convert_controlnet_ckpt_and_cache(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != ControlNetModelFormat.Diffusers:
raise NotImplementedError("Checkpoint controlnet models currently unsupported")
else:
return model_path
model_config: ControlNetModel.CheckpointConfig,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)
# return cached version if it exists
if output_path.exists():
return output_path
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
convert_controlnet_to_diffusers(
weights,
output_path,
original_config_file = app_config.root_path / model_config,
image_size = 512,
scan_needed = True,
from_safetensors = weights.suffix == ".safetensors"
)
return output_path

View File

@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -48,7 +49,7 @@ class StableDiffusionXLModel(DiffusersModel):
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
@ -108,7 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
if isinstance(config, cls.CheckpointConfig):
raise NotImplementedError('conversion of SDXL checkpoint models to diffusers format is not yet supported')
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
model_type=model_base_to_model_type[base_model],
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File

@ -15,9 +15,12 @@ from .base import (
classproperty,
InvalidModelException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
@ -235,42 +238,17 @@ class StableDiffusion2Model(DiffusersModel):
else:
return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
# note that these .yaml files don't yet exist!
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "xl-inference-v.yaml",
ModelVariantType.Inpaint: "xl-inpainting-inference.yaml",
ModelVariantType.Depth: "xl-midas-inference.yaml",
}
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None
# TODO: rework
# Note that convert_ckpt_to_diffuses does not currently support conversion of SDXL models
# pass precision - currently defaulting to fp16
def _convert_ckpt_and_cache(
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
output_path: str,
version: BaseModelType,
model_config: Union[StableDiffusion1Model.CheckpointConfig,
StableDiffusion2Model.CheckpointConfig,
StableDiffusionXLModel.CheckpointConfig,
],
output_path: str,
use_save_model: bool=False,
**kwargs,
) -> str:
"""
Convert the checkpoint model indicated in mconfig into a
@ -289,6 +267,9 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
@ -298,5 +279,43 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512)
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: "sd_xl_base.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: "sd_xl_refiner.yaml",
ModelVariantType.Inpaint: None,
ModelVariantType.Depth: None,
},
}
app_config = InvokeAIAppConfig.get_config()
try:
config_path = app_config.legacy_conf_path / ckpt_configs[version][variant]
if config_path.is_relative_to(app_config.root_path):
config_path = config_path.relative_to(app_config.root_path)
return str(config_path)
except:
return None

View File

@ -1,77 +0,0 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
try:
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
except Exception:
logger.error(
"An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
self.safety_checker.to(self.device)
features = self.safety_feature_extractor([image], return_tensors="pt")
features.to(self.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
logger.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
if caution := self.caution_img:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages

View File

@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -0,0 +1,91 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2560
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 384
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 4
context_dim: [1280, 1280, 1280, 1280] # 1280
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
legacy: False
freeze: True
layer: penultimate
always_return_pooled: True
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: aesthetic_score
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by one
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -553,7 +553,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
def onStart(self):
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
self.main_form = self.addForm(
"MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=True,
"MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=False,
)
class StderrToMessage():

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-e2437518.js"></script>
<script type="module" crossorigin src="./assets/index-e45bf5a6.js"></script>
</head>
<body dir="ltr">

View File

@ -102,8 +102,7 @@
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt",
"clearNodes": "Are you sure you want to clear all nodes?"
"imagePrompt": "Image Prompt"
},
"gallery": {
"generations": "Generations",
@ -615,6 +614,11 @@
"initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded",
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
"nodesNotValidJSON": "Not a valid JSON",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
@ -700,9 +704,10 @@
},
"nodes": {
"reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes",
"loadNodes": "Load Nodes",
"clearNodes": "Clear Nodes",
"saveGraph": "Save Graph",
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View",

View File

@ -53,11 +53,11 @@
]
},
"dependencies": {
"@chakra-ui/anatomy": "^2.1.1",
"@chakra-ui/anatomy": "^2.2.0",
"@chakra-ui/icons": "^2.0.19",
"@chakra-ui/react": "^2.7.1",
"@chakra-ui/react": "^2.8.0",
"@chakra-ui/styled-system": "^2.9.1",
"@chakra-ui/theme-tools": "^2.0.18",
"@chakra-ui/theme-tools": "^2.1.0",
"@dagrejs/graphlib": "^2.1.13",
"@dnd-kit/core": "^6.0.8",
"@dnd-kit/modifiers": "^6.0.1",

View File

@ -23,6 +23,7 @@
"menu": "Menu"
},
"common": {
"communityLabel": "Community",
"hotkeysLabel": "Hotkeys",
"darkMode": "Dark Mode",
"lightMode": "Light Mode",

View File

@ -65,18 +65,19 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
export const listenerMiddleware = createListenerMiddleware();
@ -201,3 +202,6 @@ addFirstListImagesListener();
// Ad-hoc upscale workflwo
addUpscaleRequestedListener();
// Tab Change
addTabChangedListener();

View File

@ -1,4 +1,8 @@
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import {
shouldUseNSFWCheckerChanged,
shouldUseWatermarkerChanged,
} from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import { startAppListening } from '..';
@ -6,12 +10,21 @@ export const addAppConfigReceivedListener = () => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
const { infill_methods } = action.payload;
const { infill_methods, nsfw_methods, watermarking_methods } =
action.payload;
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
dispatch(setInfillMethod(infill_methods[0]));
}
if (!nsfw_methods.includes('nsfw_checker')) {
dispatch(shouldUseNSFWCheckerChanged(false));
}
if (!watermarking_methods.includes('invisible_watermark')) {
dispatch(shouldUseWatermarkerChanged(false));
}
},
});
};

View File

@ -9,13 +9,19 @@ import {
zMainModel,
zVaeModel,
} from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
dispatch(modelChanged(result.data));
},
});
startAppListening({
predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info(
{ models: action.payload.entities },
`SDXL Refiner models loaded (${action.payload.ids.length})`
);
const currentModel = getState().sdxl.refinerModel;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
const result = zMainModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse SDXL Refiner Model'
);
return;
}
dispatch(refinerModelChanged(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {

View File

@ -1,4 +1,6 @@
import { logger } from 'app/logging/logger';
import { LIST_TAG } from 'services/api';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
@ -24,11 +26,18 @@ export const addSocketConnectedEventListener = () => {
dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
dispatch(
modelsApi.util.invalidateTags([
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
{ type: 'LoRAModel', id: LIST_TAG },
{ type: 'ControlNetModel', id: LIST_TAG },
{ type: 'VaeModel', id: LIST_TAG },
{ type: 'TextualInversionModel', id: LIST_TAG },
{ type: 'ScannedModels', id: LIST_TAG },
])
);
dispatch(appInfoApi.util.invalidateTags(['AppConfig', 'AppVersion']));
},
});
};

View File

@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
return;
}
log.debug(action.payload, 'Invocation started');
log.debug(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationStarted(action.payload));
},
});

View File

@ -0,0 +1,56 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addTabChangedListener = () => {
startAppListening({
actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => {
const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
if (!data) {
// no models yet, so we can't do anything
dispatch(modelChanged(null));
return;
}
// need to filter out all the invalid canvas models (currently, this is just sdxl)
const validCanvasModels: MainModelConfigEntity[] = [];
forEach(data.entities, (entity) => {
if (!entity) {
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);
}
});
// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas
dispatch(modelChanged(null));
return;
}
// only store the model name and base model in redux
const { base_model, model_name } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name }));
}
},
});
};

View File

@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
effect: async (action, { getState, dispatch, take }) => {
const log = logger('session');
const state = getState();
const model = state.generation.model;
let graph;
if (model && model.base_model === 'sdxl') {
graph = buildLinearSDXLImageToImageGraph(state);
} else {
graph = buildLinearImageToImageGraph(state);
}
const graph = buildLinearImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');

View File

@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session';
@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
effect: async (action, { getState, dispatch, take }) => {
const log = logger('session');
const state = getState();
const model = state.generation.model;
const graph = buildLinearTextToImageGraph(state);
let graph;
if (model && model.base_model === 'sdxl') {
graph = buildLinearSDXLTextToImageGraph(state);
} else {
graph = buildLinearTextToImageGraph(state);
}
dispatch(textToImageGraphBuilt(graph));

View File

@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
@ -47,6 +48,7 @@ const allReducers = {
imageDeletion: imageDeletionReducer,
lora: loraReducer,
modelmanager: modelmanagerReducer,
sdxl: sdxlReducer,
[api.reducerPath]: api.reducer,
};
@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'canvas',
'gallery',
'generation',
'sdxl',
'nodes',
'postprocessing',
'system',

View File

@ -114,6 +114,11 @@ const IAISlider = (props: IAIFullSliderProps) => {
setLocalInputValue(value);
}, [value]);
const numberInputMin = useMemo(
() => (sliderNumberInputProps?.min ? sliderNumberInputProps.min : min),
[min, sliderNumberInputProps?.min]
);
const numberInputMax = useMemo(
() => (sliderNumberInputProps?.max ? sliderNumberInputProps.max : max),
[max, sliderNumberInputProps?.max]
@ -129,24 +134,23 @@ const IAISlider = (props: IAIFullSliderProps) => {
const handleInputBlur = useCallback(
(e: FocusEvent<HTMLInputElement>) => {
if (e.target.value === '') {
e.target.value = String(min);
e.target.value = String(numberInputMin);
}
const clamped = clamp(
isInteger
? Math.floor(Number(e.target.value))
: Number(localInputValue),
min,
numberInputMin,
numberInputMax
);
const quantized = roundDownToMultiple(clamped, step);
onChange(quantized);
setLocalInputValue(quantized);
},
[isInteger, localInputValue, min, numberInputMax, onChange, step]
[isInteger, localInputValue, numberInputMin, numberInputMax, onChange, step]
);
const handleInputChange = useCallback((v: number | string) => {
console.log('input');
setLocalInputValue(v);
}, []);
@ -310,7 +314,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
{withInput && (
<NumberInput
min={min}
min={numberInputMin}
max={numberInputMax}
step={step}
value={localInputValue}

View File

@ -5,6 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
// import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { modelsApi } from '../../services/api/endpoints/models';
const readinessSelector = createSelector(
@ -24,7 +25,7 @@ const readinessSelector = createSelector(
}
const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state);
modelsApi.endpoints.getMainModels.select(NON_REFINER_BASE_MODELS)(state);
if (!mainModelsSuccessfullyLoaded) {
isReady = false;
reasonsWhyNotReady.push('Models are not loaded');

View File

@ -57,6 +57,11 @@ const ParamEmbeddingPopover = (props: Props) => {
});
});
// Sort Alphabetically
data.sort((a, b) =>
a.label && b.label ? (a.label?.localeCompare(b.label) ? -1 : 1) : -1
);
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [embeddingQueryData, currentMainModel?.base_model]);

View File

@ -1,5 +1,6 @@
import { Link, MenuItem } from '@chakra-ui/react';
import { MenuItem } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -33,10 +34,9 @@ import {
useRemoveImageFromBoardMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
import { useDebounce } from 'use-debounce';
import { skipToken } from '@reduxjs/toolkit/dist/query';
type SingleSelectionMenuItemsProps = {
imageDTO: ImageDTO;
@ -154,21 +154,29 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
return (
<>
<Link href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaExternalLinkAlt />}>
{t('common.openInNewTab')}
</MenuItem>
</Link>
<MenuItem
as="a"
href={imageDTO.image_url}
target="_blank"
icon={<FaExternalLinkAlt />}
>
{t('common.openInNewTab')}
</MenuItem>
{isClipboardAPIAvailable && (
<MenuItem icon={<FaCopy />} onClickCapture={handleCopyImage}>
{t('parameters.copyImage')}
</MenuItem>
)}
<Link download={true} href={imageDTO.image_url} target="_blank">
<MenuItem icon={<FaDownload />} w="100%">
{t('parameters.downloadImage')}
</MenuItem>
</Link>
<MenuItem
as="a"
download={true}
href={imageDTO.image_url}
target="_blank"
icon={<FaDownload />}
w="100%"
>
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem
icon={<FaQuoteRight />}
onClickCapture={handleRecallPrompt}

View File

@ -48,6 +48,7 @@ const ParamLora = (props: Props) => {
handleReset={handleReset}
withSliderMarks
sliderMarks={[-1, 0, 1, 2]}
sliderNumberInputProps={{ min: -50, max: 50 }}
/>
<IAIIconButton
size="sm"

View File

@ -54,7 +54,12 @@ const ParamLoRASelect = () => {
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
// Sort Alphabetically
data.sort((a, b) =>
a.label && b.label ? (a.label?.localeCompare(b.label) ? 1 : -1) : -1
);
return data.sort((a, b) => (a.disabled && !b.disabled ? -1 : 1));
}, [loras, loraModels, currentMainModel?.base_model]);
const handleChange = useCallback(

View File

@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'refiner_model' && template.type === 'refiner_model') {
return (
<RefinerModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae_model' && template.type === 'vae_model') {
return (
<VaeModelInputFieldComponent

View File

@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => {
if (!mainModels) {

View File

@ -0,0 +1,120 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
RefinerModelInputFieldTemplate,
RefinerModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const RefinerModelInputFieldComponent = (
props: FieldComponentProps<
RefinerModelInputFieldValue,
RefinerModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
return isLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(RefinerModelInputFieldComponent);

View File

@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip',
VaeField: 'vae',
model: 'model',
refiner_model: 'refiner_model',
vae_model: 'vae_model',
lora_model: 'lora_model',
controlnet_model: 'controlnet_model',
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model',
description: 'Models are models.',
},
refiner_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Refiner Model',
description: 'Models are models.',
},
vae_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),

View File

@ -70,6 +70,7 @@ export type FieldType =
| 'vae'
| 'control'
| 'model'
| 'refiner_model'
| 'vae_model'
| 'lora_model'
| 'controlnet_model'
@ -100,6 +101,7 @@ export type InputFieldValue =
| ControlInputFieldValue
| EnumInputFieldValue
| MainModelInputFieldValue
| RefinerModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
| ControlNetModelInputFieldValue
@ -128,6 +130,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| RefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
value?: MainModelParam;
};
export type RefinerModelInputFieldValue = FieldValueBase & {
type: 'refiner_model';
value?: MainModelParam;
};
export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model';
value?: VaeModelParam;
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model';
};
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'refiner_model';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'vae_model';

View File

@ -22,6 +22,7 @@ import {
LoRAModelInputFieldTemplate,
ModelInputFieldTemplate,
OutputFieldTemplate,
RefinerModelInputFieldTemplate,
StringInputFieldTemplate,
TypeHints,
UNetInputFieldTemplate,
@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
return template;
};
const buildRefinerModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
const template: RefinerModelInputFieldTemplate = {
...baseField,
type: 'refiner_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVaeModelInputFieldTemplate = ({
schemaObject,
baseField,
@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
if (['refiner_model'].includes(fieldType)) {
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
}
if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
}

View File

@ -76,6 +76,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined;
}
if (template.type === 'refiner_model') {
fieldValue.value = undefined;
}
if (template.type === 'vae_model') {
fieldValue.value = undefined;
}

View File

@ -0,0 +1,70 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
} from './constants';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
nodeToAddTo.is_intermediate = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate,
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
});
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: NSFW_CHECKER,
field: 'metadata',
},
});
}
};

View File

@ -0,0 +1,186 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_LATENTS_TO_LATENTS,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
} from './constants';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { positivePrompt, negativePrompt } = state.generation;
const {
refinerModel,
refinerAestheticScore,
positiveStylePrompt,
negativeStylePrompt,
refinerSteps,
refinerScheduler,
refinerCFGScale,
refinerStart,
} = state.sdxl;
if (!refinerModel) return;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;
metadataAccumulator.refiner_start = refinerStart;
metadataAccumulator.refiner_steps = refinerSteps;
}
// Unplug SDXL Latents Generation To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
);
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
);
// connect the VAE back to the i2l, which we just removed in the filter
// but only if we are doing l2l
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
graph.edges.push({
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER,
model: refinerModel,
};
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_POSITIVE_CONDITIONING,
style: `${positivePrompt} ${positiveStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
style: `${negativePrompt} ${negativeStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
type: 'l2l_sdxl',
id: SDXL_REFINER_LATENTS_TO_LATENTS,
cfg_scale: refinerCFGScale,
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
scheduler: refinerScheduler,
denoising_start: refinerStart,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: baseNodeId,
field: 'latents',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
}
);
};

View File

@ -0,0 +1,95 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as
| ImageNSFWBlurInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
const watermarkerNode: ImageWatermarkInvocation = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate,
};
graph.nodes[WATERMARKER] = watermarkerNode;
// no matter the situation, we want the l2i node to be intermediate
nodeToAddTo.is_intermediate = true;
if (nsfwCheckerNode) {
// if we are using NSFW checker, we need to "disable" it output by marking it intermediate,
// then connect it to the watermark node
nsfwCheckerNode.is_intermediate = true;
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
} else {
// otherwise we just connect to the watermark node
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
}
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: WATERMARKER,
field: 'metadata',
},
});
}
};

View File

@ -10,7 +10,9 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
IMAGE_TO_IMAGE_GRAPH,
@ -103,11 +105,6 @@ export const buildCanvasImageToImageGraph = (
is_intermediate: true,
skipped_layers: clipSkip,
},
[LATENTS_TO_IMAGE]: {
is_intermediate: !shouldAutoSave,
type: 'l2i',
id: LATENTS_TO_IMAGE,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
id: LATENTS_TO_LATENTS,
@ -126,6 +123,11 @@ export const buildCanvasImageToImageGraph = (
// image_name: initialImage.image_name,
// },
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: !shouldAutoSave,
},
},
edges: [
{
@ -310,17 +312,6 @@ export const buildCanvasImageToImageGraph = (
init_image: initialImage.image_name,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
@ -333,5 +324,16 @@ export const buildCanvasImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -20,6 +20,8 @@ import {
RANDOM_INT,
RANGE_OF_SIZE,
} from './constants';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
/**
* Builds the Canvas tab's Inpaint graph.
@ -249,5 +251,16 @@ export const buildCanvasInpaintGraph = (
(graph.nodes[RANGE_OF_SIZE] as RangeOfSizeInvocation).start = seed;
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, INPAINT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, INPAINT);
}
return graph;
};

View File

@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
@ -215,17 +217,6 @@ export const buildCanvasTextToImageGraph = (
clip_skip: clipSkip,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
@ -238,5 +229,16 @@ export const buildCanvasTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -9,7 +9,9 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
IMAGE_TO_IMAGE_GRAPH,
@ -46,6 +48,7 @@ export const buildLinearImageToImageGraph = (
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
// TODO: add batch functionality
@ -113,6 +116,7 @@ export const buildLinearImageToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
@ -129,6 +133,7 @@ export const buildLinearImageToImageGraph = (
// image: {
// image_name: initialImage.image_name,
// },
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
@ -294,42 +299,6 @@ export const buildLinearImageToImageGraph = (
});
}
// TODO: add batch functionality
// if (isBatchEnabled && asInitialImage && batchImageNames.length > 0) {
// // we are going to connect an iterate up to the init image
// delete (graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image;
// const imageCollection: ImageCollectionInvocation = {
// id: IMAGE_COLLECTION,
// type: 'image_collection',
// images: batchImageNames.map((image_name) => ({ image_name })),
// };
// const imageCollectionIterate: IterateInvocation = {
// id: IMAGE_COLLECTION_ITERATE,
// type: 'iterate',
// };
// graph.nodes[IMAGE_COLLECTION] = imageCollection;
// graph.nodes[IMAGE_COLLECTION_ITERATE] = imageCollectionIterate;
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION, field: 'collection' },
// destination: {
// node_id: IMAGE_COLLECTION_ITERATE,
// field: 'collection',
// },
// });
// graph.edges.push({
// source: { node_id: IMAGE_COLLECTION_ITERATE, field: 'item' },
// destination: {
// node_id: IMAGE_TO_LATENTS,
// field: 'image',
// },
// });
// }
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
@ -353,17 +322,6 @@ export const buildLinearImageToImageGraph = (
init_image: initialImage.imageName,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add LoRA support
addLoRAsToGraph(state, graph, LATENTS_TO_LATENTS);
@ -376,5 +334,16 @@ export const buildLinearImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, LATENTS_TO_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -0,0 +1,382 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import {
ImageResizeInvocation,
ImageToLatentsInvocation,
} from 'services/api/types';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
initialImage,
shouldFitToWidthHeight,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
sdxlImg2ImgDenoisingStrength: strength,
} = state.sdxl;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
if (!initialImage) {
log.error('No initial image found in state');
throw new Error('No initial image found in state');
}
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
[SDXL_LATENTS_TO_LATENTS]: {
type: 'l2l_sdxl',
id: SDXL_LATENTS_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_start: shouldUseSDXLRefiner
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
// image_name: initialImage.image_name,
// },
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
],
};
// handle `fit`
if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage.imageName,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -0,0 +1,264 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
SDXL_TEXT_TO_IMAGE_GRAPH,
SDXL_TEXT_TO_LATENTS,
} from './constants';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
export const buildLinearSDXLTextToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_TEXT_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
width,
height,
use_cpu,
},
[SDXL_TEXT_TO_LATENTS]: {
type: 't2l_sdxl',
id: SDXL_TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -5,7 +5,9 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
LATENTS_TO_IMAGE,
@ -34,6 +36,7 @@ export const buildLinearTextToImageGraph = (
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const use_cpu = shouldUseNoiseSettings
@ -95,6 +98,7 @@ export const buildLinearTextToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
@ -202,17 +206,6 @@ export const buildLinearTextToImageGraph = (
clip_skip: clipSkip,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// add LoRA support
addLoRAsToGraph(state, graph, TEXT_TO_LATENTS);
@ -225,5 +218,16 @@ export const buildLinearTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, TEXT_TO_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@ -3,6 +3,8 @@ export const POSITIVE_CONDITIONING = 'positive_conditioning';
export const NEGATIVE_CONDITIONING = 'negative_conditioning';
export const TEXT_TO_LATENTS = 'text_to_latents';
export const LATENTS_TO_IMAGE = 'latents_to_image';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
@ -23,8 +25,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide';
export const SCALE = 'scale_image';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
export const SDXL_REFINER_POSITIVE_CONDITIONING =
'sdxl_refiner_positive_conditioning';
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph';
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
export const INPAINT_GRAPH = 'inpaint_graph';

View File

@ -13,7 +13,12 @@ import {
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
const RESERVED_FIELD_NAMES = ['id', 'type', 'is_intermediate', 'metadata'];
const getReservedFieldNames = (type: string): string[] => {
if (type === 'l2i') {
return ['id', 'type', 'metadata'];
}
return ['id', 'type', 'is_intermediate', 'metadata'];
};
const invocationDenylist = [
'Graph',
@ -21,11 +26,11 @@ const invocationDenylist = [
'MetadataAccumulatorInvocation',
];
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now
export const parseSchema = (
openAPI: OpenAPIV3.Document
): Record<string, InvocationTemplate> => {
const filteredSchemas = filter(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
openAPI.components!.schemas,
openAPI.components?.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
@ -35,21 +40,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate>
>((acc, schema) => {
// only want SchemaObjects
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
const RESERVED_FIELD_NAMES = getReservedFieldNames(type);
const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
const typeHints = schema.ui?.type_hints;
const inputs: Record<string, InputFieldTemplate> = {};
if (type === 'collect') {
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
const itemProperty = schema.properties.item as InvocationSchemaObject;
inputs.item = {
type: 'item',
name: 'item',
@ -60,10 +61,8 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
const itemProperty = schema.properties
.collection as InvocationSchemaObject;
inputs.collection = {
type: 'array',
name: 'collection',
@ -74,18 +73,18 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!RESERVED_FIELD_NAMES.includes(propertyName) &&
isSchemaObject(property)
) {
const field: InputFieldTemplate | undefined =
buildInputFieldTemplate(property, propertyName, typeHints);
const field = buildInputFieldTemplate(
property,
propertyName,
typeHints
);
if (field) {
inputsAccumulator[propertyName] = field;
}
@ -97,22 +96,17 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
}
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const iterationOutput = openAPI.components!.schemas![
const iterationOutput = openAPI.components?.schemas?.[
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
title: iterationOutput?.title ?? '',
description: iterationOutput?.description ?? '',
type: 'array',
},
};

View File

@ -7,9 +7,9 @@ import { activeTabNameSelector } from '../../../../ui/store/uiSelectors';
const aspectRatios = [
{ name: 'Free', value: null },
{ name: 'Portrait', value: 0.67 / 1 },
{ name: 'Wide', value: 16 / 9 },
{ name: 'Square', value: 1 / 1 },
{ name: '2:3', value: 2 / 3 },
{ name: '16:9', value: 16 / 9 },
{ name: '1:1', value: 1 / 1 },
];
export default function ParamAspectRatio() {

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler';
import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';
const ParamModelandVAEandScheduler = () => {
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
<Box w="full">
<ParamMainModelSelect />
</Box>
<Flex gap={3} w="full">
{isVaeEnabled && (
<Box w="full">
<ParamVAEModelSelect />
</Box>
)}
<Box w="full">
<ParamScheduler />
</Box>
</Flex>
<Box w="full">
<ParamScheduler />
</Box>
{isVaeEnabled && (
<Flex w="full" gap={3}>
<ParamVAEModelSelect />
<ParamVAEPrecision />
</Flex>
)}
</Flex>
);
};

View File

@ -13,7 +13,9 @@ import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
@ -29,8 +31,12 @@ const ParamMainModelSelect = () => {
const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const activeTabName = useAppSelector(activeTabNameSelector);
const data = useMemo(() => {
if (!mainModels) {
@ -40,7 +46,10 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => {
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
if (
!model ||
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
) {
return;
}
@ -52,7 +61,7 @@ const ParamMainModelSelect = () => {
});
return data;
}, [mainModels]);
}, [mainModels, activeTabName]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
@ -88,7 +97,7 @@ const ParamMainModelSelect = () => {
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<Flex w="100%" alignItems="center" gap={3}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}

View File

@ -32,11 +32,6 @@ export default function ParamSeed() {
isInvalid={seed < 0 && shouldGenerateVariations}
onChange={handleChangeSeed}
value={seed}
formControlProps={{
display: 'flex',
alignItems: 'center',
gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
}}
/>
);
}

View File

@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';
const ParamSeedFull = () => {
return (
<Flex sx={{ gap: 4, alignItems: 'center' }}>
<Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
<ParamSeed />
<ParamSeedShuffle />
<ParamSeedRandomize />

View File

@ -0,0 +1,46 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { vaePrecision } = generation;
return { vaePrecision };
},
defaultSelectorOptions
);
const DATA = ['fp16', 'fp32'];
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { vaePrecision } = useAppSelector(selector);
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaePrecisionChanged(v as PrecisionParam));
},
[dispatch]
);
return (
<IAIMantineSelect
label="VAE Precision"
value={vaePrecision}
data={DATA}
onChange={handleChange}
/>
);
};
export default memo(ParamVAEModelSelect);

View File

@ -11,6 +11,7 @@ import {
MainModelParam,
NegativePromptParam,
PositivePromptParam,
PrecisionParam,
SchedulerParam,
SeedParam,
StepsParam,
@ -51,6 +52,7 @@ export interface GenerationState {
verticalSymmetrySteps: number;
model: MainModelField | null;
vae: VaeModelParam | null;
vaePrecision: PrecisionParam;
seamlessXAxis: boolean;
seamlessYAxis: boolean;
clipSkip: number;
@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
verticalSymmetrySteps: 0,
model: null,
vae: null,
vaePrecision: 'fp32',
seamlessXAxis: false,
seamlessYAxis: false,
clipSkip: 0,
@ -241,6 +244,9 @@ export const generationSlice = createSlice({
// null is a valid VAE!
state.vae = action.payload;
},
vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
state.vaePrecision = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload;
},
@ -327,6 +333,7 @@ export const {
shouldUseCpuNoiseChanged,
setShouldShowAdvancedOptions,
setAspectRatio,
vaePrecisionChanged,
} = generationSlice.actions;
export default generationSlice.reducer;

View File

@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
val: unknown
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
/**
* Zod schema for SDXL positive style prompt parameter
*/
export const zPositiveStylePromptSDXL = z.string();
/**
* Type alias for SDXL positive style prompt parameter, inferred from its zod schema
*/
export type PositiveStylePromptSDXLParam = z.infer<
typeof zPositiveStylePromptSDXL
>;
/**
* Validates/type-guards a value as a SDXL positive style prompt parameter
*/
export const isValidSDXLPositiveStylePrompt = (
val: unknown
): val is PositiveStylePromptSDXLParam =>
zPositiveStylePromptSDXL.safeParse(val).success;
/**
* Zod schema for SDXL negative style prompt parameter
*/
export const zNegativeStylePromptSDXL = z.string();
/**
* Type alias for SDXL negative style prompt parameter, inferred from its zod schema
*/
export type NegativeStylePromptSDXLParam = z.infer<
typeof zNegativeStylePromptSDXL
>;
/**
* Validates/type-guards a value as a SDXL negative style prompt parameter
*/
export const isValidSDXLNegativeStylePrompt = (
val: unknown
): val is NegativeStylePromptSDXLParam =>
zNegativeStylePromptSDXL.safeParse(val).success;
/**
* Zod schema for steps parameter
*/
@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success;
/**
* Zod schema for a precision parameter
*/
export const zPrecision = z.enum(['fp16', 'fp32']);
/**
* Type alias for precision parameter, inferred from its zod schema
*/
export type PrecisionParam = z.infer<typeof zPrecision>;
/**
* Validates/type-guards a value as a precision parameter
*/
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success;
// /**
// * Zod schema for BaseModelType
// */

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { sdxlImg2ImgDenoisingStrength } = sdxl;
return {
sdxlImg2ImgDenoisingStrength,
};
},
defaultSelectorOptions
);
const ParamSDXLImg2ImgDenoisingStrength = () => {
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
}, [dispatch]);
return (
<IAISlider
label={`${t('parameters.denoisingStrength')}`}
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={sdxlImg2ImgDenoisingStrength}
isInteger={false}
withInput
withSliderMarks
withReset
/>
);
};
export default memo(ParamSDXLImg2ImgDenoisingStrength);

View File

@ -0,0 +1,149 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.negativeStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLNegativeStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativeStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativeStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Negative Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLNegativeStyleConditioning;

View File

@ -0,0 +1,148 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.positiveStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLPositiveStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositiveStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositiveStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Positive Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLPositiveStyleConditioning;

View File

@ -0,0 +1,48 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseSDXLRefiner } = state.sdxl;
const { shouldUseSliders } = state.ui;
return {
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCollapse = () => {
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
return (
<IAICollapse label="Refiner" activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamUseSDXLRefiner />
<ParamSDXLRefinerModelSelect />
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
<ParamSDXLRefinerSteps />
<ParamSDXLRefinerCFGScale />
</Flex>
<ParamSDXLRefinerScheduler />
<ParamSDXLRefinerAestheticScore />
<ParamSDXLRefinerStart />
</Flex>
</IAICollapse>
);
};
export default ParamSDXLRefinerCollapse;

View File

@ -0,0 +1,78 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const { shouldUseSliders } = ui;
const { shouldRandomizeSeed } = generation;
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, activeLabel };
},
defaultSelectorOptions
);
const SDXLImageToImageTabCoreParameters = () => {
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
return (
<IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex
sx={{
flexDirection: 'column',
gap: 3,
}}
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
)}
<ParamSDXLImg2ImgDenoisingStrength />
<ImageToImageFit />
</Flex>
</IAICollapse>
);
};
export default memo(SDXLImageToImageTabCoreParameters);

View File

@ -0,0 +1,28 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
const SDXLImageToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLImageToImageTabParameters;

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, hotkeys }) => {
const { refinerAestheticScore } = sdxl;
const { shift } = hotkeys;
return {
refinerAestheticScore,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerAestheticScore = () => {
const { refinerAestheticScore, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerAestheticScore(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerAestheticScore(6)),
[dispatch]
);
return (
<IAISlider
label="Aesthetic Score"
step={shift ? 0.1 : 0.5}
min={1}
max={10}
onChange={handleChange}
handleReset={handleReset}
value={refinerAestheticScore}
sliderNumberInputProps={{ max: 10 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerAestheticScore);

View File

@ -0,0 +1,75 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui, hotkeys }) => {
const { refinerCFGScale } = sdxl;
const { shouldUseSliders } = ui;
const { shift } = hotkeys;
return {
refinerCFGScale,
shouldUseSliders,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCFGScale = () => {
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerCFGScale(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerCFGScale(7)),
[dispatch]
);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.cfgScale')}
step={shift ? 0.1 : 0.5}
min={1}
max={20}
onChange={handleChange}
handleReset={handleReset}
value={refinerCFGScale}
sliderNumberInputProps={{ max: 200 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.cfgScale')}
step={0.5}
min={1}
max={200}
onChange={handleChange}
value={refinerCFGScale}
isInteger={false}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerCFGScale);

View File

@ -0,0 +1,111 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
(state) => ({ model: state.sdxl.refinerModel }),
defaultSelectorOptions
);
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { model } = useAppSelector(selector);
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${model?.base_model}/main/${model?.model_name}`
] ?? null,
[refinerModels?.entities, model]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(refinerModelChanged(newModel));
},
[dispatch]
);
return isLoading ? (
<IAIMantineSearchableSelect
label="Refiner Model"
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label="Refiner Model"
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
w="100%"
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(ParamSDXLRefinerModelSelect);

View File

@ -0,0 +1,65 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,
({ ui, sdxl }) => {
const { refinerScheduler } = sdxl;
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return {
refinerScheduler,
data,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerScheduler = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { refinerScheduler, data } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(setRefinerScheduler(v as SchedulerParam));
},
[dispatch]
);
return (
<IAIMantineSearchableSelect
w="100%"
label={t('parameters.scheduler')}
value={refinerScheduler}
data={data}
onChange={handleChange}
disabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerScheduler);

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { refinerStart } = sdxl;
return {
refinerStart,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerStart = () => {
const { refinerStart } = useAppSelector(selector);
const dispatch = useAppDispatch();
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerStart(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerStart(0.7)),
[dispatch]
);
return (
<IAISlider
label="Refiner Start"
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerStart}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerStart);

View File

@ -0,0 +1,72 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui }) => {
const { refinerSteps } = sdxl;
const { shouldUseSliders } = ui;
return {
refinerSteps,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerSteps = () => {
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => {
dispatch(setRefinerSteps(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setRefinerSteps(20));
}, [dispatch]);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.steps')}
min={1}
max={100}
step={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerSteps}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: 500 }}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.steps')}
min={1}
max={500}
step={1}
onChange={handleChange}
value={refinerSteps}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerSteps);

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldUseSDXLRefiner(e.target.checked));
};
return (
<IAISwitch
label="Use Refiner"
isChecked={shouldUseSDXLRefiner}
onChange={handleUseSDXLRefinerChange}
isDisabled={!isRefinerAvailable}
/>
);
}

View File

@ -0,0 +1,27 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLTextToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLTextToImageTabParameters;

View File

@ -0,0 +1,89 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
MainModelParam,
NegativeStylePromptSDXLParam,
PositiveStylePromptSDXLParam,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { MainModelField } from 'services/api/types';
type SDXLInitialState = {
positiveStylePrompt: PositiveStylePromptSDXLParam;
negativeStylePrompt: NegativeStylePromptSDXLParam;
shouldUseSDXLRefiner: boolean;
sdxlImg2ImgDenoisingStrength: number;
refinerModel: MainModelField | null;
refinerSteps: number;
refinerCFGScale: number;
refinerScheduler: SchedulerParam;
refinerAestheticScore: number;
refinerStart: number;
};
const sdxlInitialState: SDXLInitialState = {
positiveStylePrompt: '',
negativeStylePrompt: '',
shouldUseSDXLRefiner: false,
sdxlImg2ImgDenoisingStrength: 0.7,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerAestheticScore: 6,
refinerStart: 0.7,
};
const sdxlSlice = createSlice({
name: 'sdxl',
initialState: sdxlInitialState,
reducers: {
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.positiveStylePrompt = action.payload;
},
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.negativeStylePrompt = action.payload;
},
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
state.shouldUseSDXLRefiner = action.payload;
},
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
state.sdxlImg2ImgDenoisingStrength = action.payload;
},
refinerModelChanged: (
state,
action: PayloadAction<MainModelParam | null>
) => {
state.refinerModel = action.payload;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
},
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
state.refinerCFGScale = action.payload;
},
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
state.refinerScheduler = action.payload;
},
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
state.refinerAestheticScore = action.payload;
},
setRefinerStart: (state, action: PayloadAction<number>) => {
state.refinerStart = action.payload;
},
},
});
export const {
setPositiveStylePromptSDXL,
setNegativeStylePromptSDXL,
setShouldUseSDXLRefiner,
setSDXLImg2ImgDenoisingStrength,
refinerModelChanged,
setRefinerSteps,
setRefinerCFGScale,
setRefinerScheduler,
setRefinerAestheticScore,
setRefinerStart,
} = sdxlSlice.actions;
export default sdxlSlice.reducer;

Some files were not shown because too many files have changed in this diff Show More