mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add SDXL To Linear UI (#3973)
## What type of PR is this? (check all applicable) - [x] Feature ## Have you discussed this change with the InvokeAI team? - [x] Yes ## Description This PR adds support for SDXL Models in the Linear UI ### DONE - SDXL Base Text To Image Support - SDXL Base Image To Image Support - SDXL Refiner Support - SDXL Relevant UI ## [optional] Are there any post deployment tasks we need to perform? Double check to ensure nothing major changed with 1.0 -- In any case those changes would be backend related mostly. If Refiner is scrapped for 1.0 models, then we simply disable the Refiner Graph.
This commit is contained in:
commit
531bc40d3f
@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
|
|||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}))
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
|
|||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
|
@ -2,16 +2,19 @@ from typing import Literal, Optional, Union
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocationOutput, InvocationConfig,
|
BaseInvocation,
|
||||||
InvocationContext)
|
BaseInvocationOutput,
|
||||||
|
InvocationConfig,
|
||||||
|
InvocationContext,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
VAEModelField)
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModel):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
lora: LoRAModelField = Field(description="The LoRA model")
|
lora: LoRAModelField = Field(description="The LoRA model")
|
||||||
weight: float = Field(description="The weight of the LoRA model")
|
weight: float = Field(description="The weight of the LoRA model")
|
||||||
|
|
||||||
@ -19,7 +22,9 @@ class LoRAMetadataField(BaseModel):
|
|||||||
class CoreMetadata(BaseModel):
|
class CoreMetadata(BaseModel):
|
||||||
"""Core generation metadata for an image generated in InvokeAI."""
|
"""Core generation metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
generation_mode: str = Field(
|
||||||
|
description="The generation mode that output this image",
|
||||||
|
)
|
||||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||||
width: int = Field(description="The width parameter")
|
width: int = Field(description="The width parameter")
|
||||||
@ -29,10 +34,20 @@ class CoreMetadata(BaseModel):
|
|||||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||||
steps: int = Field(description="The number of steps used for inference")
|
steps: int = Field(description="The number of steps used for inference")
|
||||||
scheduler: str = Field(description="The scheduler used for inference")
|
scheduler: str = Field(description="The scheduler used for inference")
|
||||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
clip_skip: int = Field(
|
||||||
|
description="The number of skipped CLIP layers",
|
||||||
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(
|
||||||
|
description="The ControlNets used for inference"
|
||||||
|
)
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
|
vae: Union[VAEModelField, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Latents-to-Latents
|
||||||
strength: Union[float, None] = Field(
|
strength: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
@ -40,9 +55,34 @@ class CoreMetadata(BaseModel):
|
|||||||
init_image: Union[str, None] = Field(
|
init_image: Union[str, None] = Field(
|
||||||
default=None, description="The name of the initial image"
|
default=None, description="The name of the initial image"
|
||||||
)
|
)
|
||||||
vae: Union[VAEModelField, None] = Field(
|
|
||||||
|
# SDXL
|
||||||
|
positive_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The positive style prompt parameter"
|
||||||
|
)
|
||||||
|
negative_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The negative style prompt parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
# SDXL Refiner
|
||||||
|
refiner_model: Union[MainModelField, None] = Field(
|
||||||
|
default=None, description="The SDXL Refiner model used"
|
||||||
|
)
|
||||||
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_steps: Union[int, None] = Field(
|
||||||
|
default=None, description="The number of steps used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_scheduler: Union[str, None] = Field(
|
||||||
|
default=None, description="The scheduler used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_start: Union[float, None] = Field(
|
||||||
|
default=None, description="The start value used for refiner denoising"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -71,7 +111,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||||
|
|
||||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
generation_mode: str = Field(
|
||||||
|
description="The generation mode that output this image",
|
||||||
|
)
|
||||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||||
width: int = Field(description="The width parameter")
|
width: int = Field(description="The width parameter")
|
||||||
@ -81,9 +123,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||||
steps: int = Field(description="The number of steps used for inference")
|
steps: int = Field(description="The number of steps used for inference")
|
||||||
scheduler: str = Field(description="The scheduler used for inference")
|
scheduler: str = Field(description="The scheduler used for inference")
|
||||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
clip_skip: int = Field(
|
||||||
|
description="The number of skipped CLIP layers",
|
||||||
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(
|
||||||
|
description="The ControlNets used for inference"
|
||||||
|
)
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
strength: Union[float, None] = Field(
|
strength: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -97,36 +143,44 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# SDXL
|
||||||
|
positive_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The positive style prompt parameter"
|
||||||
|
)
|
||||||
|
negative_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The negative style prompt parameter"
|
||||||
|
)
|
||||||
|
|
||||||
|
# SDXL Refiner
|
||||||
|
refiner_model: Union[MainModelField, None] = Field(
|
||||||
|
default=None, description="The SDXL Refiner model used"
|
||||||
|
)
|
||||||
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
|
)
|
||||||
|
refiner_steps: Union[int, None] = Field(
|
||||||
|
default=None, description="The number of steps used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_scheduler: Union[str, None] = Field(
|
||||||
|
default=None, description="The scheduler used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_start: Union[float, None] = Field(
|
||||||
|
default=None, description="The start value used for refiner denoising"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Metadata Accumulator",
|
"title": "Metadata Accumulator",
|
||||||
"tags": ["image", "metadata", "generation"]
|
"tags": ["image", "metadata", "generation"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
return MetadataAccumulatorOutput(
|
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
||||||
metadata=CoreMetadata(
|
|
||||||
generation_mode=self.generation_mode,
|
|
||||||
positive_prompt=self.positive_prompt,
|
|
||||||
negative_prompt=self.negative_prompt,
|
|
||||||
width=self.width,
|
|
||||||
height=self.height,
|
|
||||||
seed=self.seed,
|
|
||||||
rand_device=self.rand_device,
|
|
||||||
cfg_scale=self.cfg_scale,
|
|
||||||
steps=self.steps,
|
|
||||||
scheduler=self.scheduler,
|
|
||||||
model=self.model,
|
|
||||||
strength=self.strength,
|
|
||||||
init_image=self.init_image,
|
|
||||||
vae=self.vae,
|
|
||||||
controlnets=self.controlnets,
|
|
||||||
loras=self.loras,
|
|
||||||
clip_skip=self.clip_skip,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Model Loader",
|
"title": "SDXL Refiner Model Loader",
|
||||||
"tags": ["model", "loader", "sdxl_refiner"],
|
"tags": ["model", "loader", "sdxl_refiner"],
|
||||||
"type_hints": {"model": "model"},
|
"type_hints": {"model": "refiner_model"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context
|
||||||
)
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
latents: Optional[LatentsField] = Field(description="Initial latents")
|
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||||
|
|
||||||
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
|
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||||
|
|
||||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
num_inference_steps = num_inference_steps - t_start
|
num_inference_steps = num_inference_steps - t_start
|
||||||
|
|
||||||
# apply noise(if provided)
|
# apply noise(if provided)
|
||||||
if self.noise is not None:
|
if self.noise is not None and timesteps.shape[0] > 0:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||||
del noise
|
del noise
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict()
|
**self.unet.unet.dict(), context=context,
|
||||||
)
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
|
@ -65,18 +65,19 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
|
|||||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||||
|
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||||
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
||||||
|
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||||
|
import { addTabChangedListener } from './listeners/tabChanged';
|
||||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
|
||||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
|
||||||
|
|
||||||
export const listenerMiddleware = createListenerMiddleware();
|
export const listenerMiddleware = createListenerMiddleware();
|
||||||
|
|
||||||
@ -201,3 +202,6 @@ addFirstListImagesListener();
|
|||||||
|
|
||||||
// Ad-hoc upscale workflwo
|
// Ad-hoc upscale workflwo
|
||||||
addUpscaleRequestedListener();
|
addUpscaleRequestedListener();
|
||||||
|
|
||||||
|
// Tab Change
|
||||||
|
addTabChangedListener();
|
||||||
|
@ -9,13 +9,19 @@ import {
|
|||||||
zMainModel,
|
zMainModel,
|
||||||
zVaeModel,
|
zVaeModel,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import {
|
||||||
|
refinerModelChanged,
|
||||||
|
setShouldUseSDXLRefiner,
|
||||||
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import { forEach, some } from 'lodash-es';
|
import { forEach, some } from 'lodash-es';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
export const addModelsLoadedListener = () => {
|
export const addModelsLoadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
predicate: (state, action) =>
|
||||||
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
|
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
const log = logger('models');
|
const log = logger('models');
|
||||||
@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
|
|||||||
dispatch(modelChanged(result.data));
|
dispatch(modelChanged(result.data));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
startAppListening({
|
||||||
|
predicate: (state, action) =>
|
||||||
|
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||||
|
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||||
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||||
|
const log = logger('models');
|
||||||
|
log.info(
|
||||||
|
{ models: action.payload.entities },
|
||||||
|
`SDXL Refiner models loaded (${action.payload.ids.length})`
|
||||||
|
);
|
||||||
|
|
||||||
|
const currentModel = getState().sdxl.refinerModel;
|
||||||
|
|
||||||
|
const isCurrentModelAvailable = some(
|
||||||
|
action.payload.entities,
|
||||||
|
(m) =>
|
||||||
|
m?.model_name === currentModel?.model_name &&
|
||||||
|
m?.base_model === currentModel?.base_model
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isCurrentModelAvailable) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstModelId = action.payload.ids[0];
|
||||||
|
const firstModel = action.payload.entities[firstModelId];
|
||||||
|
|
||||||
|
if (!firstModel) {
|
||||||
|
// No models loaded at all
|
||||||
|
dispatch(refinerModelChanged(null));
|
||||||
|
dispatch(setShouldUseSDXLRefiner(false));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = zMainModel.safeParse(firstModel);
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
log.error(
|
||||||
|
{ error: result.error.format() },
|
||||||
|
'Failed to parse SDXL Refiner Model'
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(refinerModelChanged(result.data));
|
||||||
|
},
|
||||||
|
});
|
||||||
startAppListening({
|
startAppListening({
|
||||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||||
effect: async (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch }) => {
|
||||||
|
@ -3,6 +3,11 @@ import { modelsApi } from 'services/api/endpoints/models';
|
|||||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||||
import { startAppListening } from '../..';
|
import { startAppListening } from '../..';
|
||||||
|
import {
|
||||||
|
ALL_BASE_MODELS,
|
||||||
|
NON_REFINER_BASE_MODELS,
|
||||||
|
REFINER_BASE_MODELS,
|
||||||
|
} from 'services/api/constants';
|
||||||
|
|
||||||
export const addSocketConnectedEventListener = () => {
|
export const addSocketConnectedEventListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -24,7 +29,11 @@ export const addSocketConnectedEventListener = () => {
|
|||||||
dispatch(appSocketConnected(action.payload));
|
dispatch(appSocketConnected(action.payload));
|
||||||
|
|
||||||
// update all server state
|
// update all server state
|
||||||
dispatch(modelsApi.endpoints.getMainModels.initiate());
|
dispatch(modelsApi.endpoints.getMainModels.initiate(REFINER_BASE_MODELS));
|
||||||
|
dispatch(
|
||||||
|
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
|
||||||
|
);
|
||||||
|
dispatch(modelsApi.endpoints.getMainModels.initiate(ALL_BASE_MODELS));
|
||||||
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
|
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
|
||||||
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
|
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
|
||||||
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
||||||
|
@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug(action.payload, 'Invocation started');
|
log.debug(
|
||||||
|
action.payload,
|
||||||
|
`Invocation started (${action.payload.data.node.type})`
|
||||||
|
);
|
||||||
dispatch(appSocketInvocationStarted(action.payload));
|
dispatch(appSocketInvocationStarted(action.payload));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import {
|
||||||
|
MainModelConfigEntity,
|
||||||
|
modelsApi,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
import { startAppListening } from '..';
|
||||||
|
|
||||||
|
export const addTabChangedListener = () => {
|
||||||
|
startAppListening({
|
||||||
|
actionCreator: setActiveTab,
|
||||||
|
effect: (action, { getState, dispatch }) => {
|
||||||
|
const activeTabName = action.payload;
|
||||||
|
if (activeTabName === 'unifiedCanvas') {
|
||||||
|
// grab the models from RTK Query cache
|
||||||
|
const { data } = modelsApi.endpoints.getMainModels.select(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
)(getState());
|
||||||
|
|
||||||
|
if (!data) {
|
||||||
|
// no models yet, so we can't do anything
|
||||||
|
dispatch(modelChanged(null));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// need to filter out all the invalid canvas models (currently, this is just sdxl)
|
||||||
|
const validCanvasModels: MainModelConfigEntity[] = [];
|
||||||
|
|
||||||
|
forEach(data.entities, (entity) => {
|
||||||
|
if (!entity) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
|
||||||
|
validCanvasModels.push(entity);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// this could still be undefined even tho TS doesn't say so
|
||||||
|
const firstValidCanvasModel = validCanvasModels[0];
|
||||||
|
|
||||||
|
if (!firstValidCanvasModel) {
|
||||||
|
// uh oh, we have no models that are valid for canvas
|
||||||
|
dispatch(modelChanged(null));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// only store the model name and base model in redux
|
||||||
|
const { base_model, model_name } = firstValidCanvasModel;
|
||||||
|
|
||||||
|
dispatch(modelChanged({ base_model, model_name }));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
};
|
@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
|
|||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||||
|
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
import { sessionCreated } from 'services/api/thunks/session';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
|
|||||||
effect: async (action, { getState, dispatch, take }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const log = logger('session');
|
const log = logger('session');
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
const model = state.generation.model;
|
||||||
|
|
||||||
|
let graph;
|
||||||
|
|
||||||
|
if (model && model.base_model === 'sdxl') {
|
||||||
|
graph = buildLinearSDXLImageToImageGraph(state);
|
||||||
|
} else {
|
||||||
|
graph = buildLinearImageToImageGraph(state);
|
||||||
|
}
|
||||||
|
|
||||||
const graph = buildLinearImageToImageGraph(state);
|
|
||||||
dispatch(imageToImageGraphBuilt(graph));
|
dispatch(imageToImageGraphBuilt(graph));
|
||||||
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');
|
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
|||||||
import { userInvoked } from 'app/store/actions';
|
import { userInvoked } from 'app/store/actions';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||||
|
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||||
import { sessionCreated } from 'services/api/thunks/session';
|
import { sessionCreated } from 'services/api/thunks/session';
|
||||||
@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
|
|||||||
effect: async (action, { getState, dispatch, take }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
const log = logger('session');
|
const log = logger('session');
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
const model = state.generation.model;
|
||||||
|
|
||||||
const graph = buildLinearTextToImageGraph(state);
|
let graph;
|
||||||
|
|
||||||
|
if (model && model.base_model === 'sdxl') {
|
||||||
|
graph = buildLinearSDXLTextToImageGraph(state);
|
||||||
|
} else {
|
||||||
|
graph = buildLinearTextToImageGraph(state);
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(textToImageGraphBuilt(graph));
|
dispatch(textToImageGraphBuilt(graph));
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
|
|||||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||||
import generationReducer from 'features/parameters/store/generationSlice';
|
import generationReducer from 'features/parameters/store/generationSlice';
|
||||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||||
|
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||||
import configReducer from 'features/system/store/configSlice';
|
import configReducer from 'features/system/store/configSlice';
|
||||||
import systemReducer from 'features/system/store/systemSlice';
|
import systemReducer from 'features/system/store/systemSlice';
|
||||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||||
@ -47,6 +48,7 @@ const allReducers = {
|
|||||||
imageDeletion: imageDeletionReducer,
|
imageDeletion: imageDeletionReducer,
|
||||||
lora: loraReducer,
|
lora: loraReducer,
|
||||||
modelmanager: modelmanagerReducer,
|
modelmanager: modelmanagerReducer,
|
||||||
|
sdxl: sdxlReducer,
|
||||||
[api.reducerPath]: api.reducer,
|
[api.reducerPath]: api.reducer,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
|||||||
'canvas',
|
'canvas',
|
||||||
'gallery',
|
'gallery',
|
||||||
'generation',
|
'generation',
|
||||||
|
'sdxl',
|
||||||
'nodes',
|
'nodes',
|
||||||
'postprocessing',
|
'postprocessing',
|
||||||
'system',
|
'system',
|
||||||
|
@ -6,6 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
|||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { modelsApi } from '../../services/api/endpoints/models';
|
import { modelsApi } from '../../services/api/endpoints/models';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
const readinessSelector = createSelector(
|
const readinessSelector = createSelector(
|
||||||
[stateSelector, activeTabNameSelector],
|
[stateSelector, activeTabNameSelector],
|
||||||
@ -24,7 +25,7 @@ const readinessSelector = createSelector(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||||
modelsApi.endpoints.getMainModels.select()(state);
|
modelsApi.endpoints.getMainModels.select(ALL_BASE_MODELS)(state);
|
||||||
if (!mainModelsSuccessfullyLoaded) {
|
if (!mainModelsSuccessfullyLoaded) {
|
||||||
isReady = false;
|
isReady = false;
|
||||||
reasonsWhyNotReady.push('Models are not loaded');
|
reasonsWhyNotReady.push('Models are not loaded');
|
||||||
|
@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
|||||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||||
|
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||||
|
|
||||||
type InputFieldComponentProps = {
|
type InputFieldComponentProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||||
|
return (
|
||||||
|
<RefinerModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||||
return (
|
return (
|
||||||
<VaeModelInputFieldComponent
|
<VaeModelInputFieldComponent
|
||||||
|
@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
|||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { FieldComponentProps } from './types';
|
import { FieldComponentProps } from './types';
|
||||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||||
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
|
@ -0,0 +1,120 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
RefinerModelInputFieldTemplate,
|
||||||
|
RefinerModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
|
const RefinerModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
RefinerModelInputFieldValue,
|
||||||
|
RefinerModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
const { data: refinerModels, isLoading } =
|
||||||
|
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!refinerModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(refinerModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [refinerModels]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
refinerModels?.entities[
|
||||||
|
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeModel = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newModel = modelIdToMainModelParam(v);
|
||||||
|
|
||||||
|
if (!newModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: newModel,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return isLoading ? (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
label={t('modelManager.model')}
|
||||||
|
placeholder="Loading..."
|
||||||
|
disabled={true}
|
||||||
|
data={[]}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex w="100%" alignItems="center" gap={2}>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={selectedModel?.id}
|
||||||
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
|
data={data}
|
||||||
|
error={data.length === 0}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
onChange={handleChangeModel}
|
||||||
|
/>
|
||||||
|
{isSyncModelEnabled && (
|
||||||
|
<Box mt={7}>
|
||||||
|
<SyncModelsButton iconMode />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(RefinerModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
ClipField: 'clip',
|
ClipField: 'clip',
|
||||||
VaeField: 'vae',
|
VaeField: 'vae',
|
||||||
model: 'model',
|
model: 'model',
|
||||||
|
refiner_model: 'refiner_model',
|
||||||
vae_model: 'vae_model',
|
vae_model: 'vae_model',
|
||||||
lora_model: 'lora_model',
|
lora_model: 'lora_model',
|
||||||
controlnet_model: 'controlnet_model',
|
controlnet_model: 'controlnet_model',
|
||||||
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'Model',
|
title: 'Model',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
refiner_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'Refiner Model',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
vae_model: {
|
vae_model: {
|
||||||
color: 'teal',
|
color: 'teal',
|
||||||
colorCssVar: getColorTokenCssVariable('teal'),
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
@ -70,6 +70,7 @@ export type FieldType =
|
|||||||
| 'vae'
|
| 'vae'
|
||||||
| 'control'
|
| 'control'
|
||||||
| 'model'
|
| 'model'
|
||||||
|
| 'refiner_model'
|
||||||
| 'vae_model'
|
| 'vae_model'
|
||||||
| 'lora_model'
|
| 'lora_model'
|
||||||
| 'controlnet_model'
|
| 'controlnet_model'
|
||||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
|||||||
| ControlInputFieldValue
|
| ControlInputFieldValue
|
||||||
| EnumInputFieldValue
|
| EnumInputFieldValue
|
||||||
| MainModelInputFieldValue
|
| MainModelInputFieldValue
|
||||||
|
| RefinerModelInputFieldValue
|
||||||
| VaeModelInputFieldValue
|
| VaeModelInputFieldValue
|
||||||
| LoRAModelInputFieldValue
|
| LoRAModelInputFieldValue
|
||||||
| ControlNetModelInputFieldValue
|
| ControlNetModelInputFieldValue
|
||||||
@ -128,6 +130,7 @@ export type InputFieldTemplate =
|
|||||||
| ControlInputFieldTemplate
|
| ControlInputFieldTemplate
|
||||||
| EnumInputFieldTemplate
|
| EnumInputFieldTemplate
|
||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
|
| RefinerModelInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate
|
| VaeModelInputFieldTemplate
|
||||||
| LoRAModelInputFieldTemplate
|
| LoRAModelInputFieldTemplate
|
||||||
| ControlNetModelInputFieldTemplate
|
| ControlNetModelInputFieldTemplate
|
||||||
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: MainModelParam;
|
value?: MainModelParam;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type RefinerModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'refiner_model';
|
||||||
|
value?: MainModelParam;
|
||||||
|
};
|
||||||
|
|
||||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||||
type: 'vae_model';
|
type: 'vae_model';
|
||||||
value?: VaeModelParam;
|
value?: VaeModelParam;
|
||||||
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'model';
|
type: 'model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'refiner_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: string;
|
default: string;
|
||||||
type: 'vae_model';
|
type: 'vae_model';
|
||||||
|
@ -22,6 +22,7 @@ import {
|
|||||||
LoRAModelInputFieldTemplate,
|
LoRAModelInputFieldTemplate,
|
||||||
ModelInputFieldTemplate,
|
ModelInputFieldTemplate,
|
||||||
OutputFieldTemplate,
|
OutputFieldTemplate,
|
||||||
|
RefinerModelInputFieldTemplate,
|
||||||
StringInputFieldTemplate,
|
StringInputFieldTemplate,
|
||||||
TypeHints,
|
TypeHints,
|
||||||
UNetInputFieldTemplate,
|
UNetInputFieldTemplate,
|
||||||
@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildRefinerModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
|
||||||
|
const template: RefinerModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'refiner_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildVaeModelInputFieldTemplate = ({
|
const buildVaeModelInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['model'].includes(fieldType)) {
|
if (['model'].includes(fieldType)) {
|
||||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['refiner_model'].includes(fieldType)) {
|
||||||
|
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['vae_model'].includes(fieldType)) {
|
if (['vae_model'].includes(fieldType)) {
|
||||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -76,6 +76,10 @@ export const buildInputFieldValue = (
|
|||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'refiner_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
|
|
||||||
if (template.type === 'vae_model') {
|
if (template.type === 'vae_model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,186 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||||
|
import { NonNullableGraph } from '../../types/types';
|
||||||
|
import {
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
METADATA_ACCUMULATOR,
|
||||||
|
SDXL_LATENTS_TO_LATENTS,
|
||||||
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
SDXL_REFINER_MODEL_LOADER,
|
||||||
|
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
export const addSDXLRefinerToGraph = (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
baseNodeId: string
|
||||||
|
): void => {
|
||||||
|
const { positivePrompt, negativePrompt } = state.generation;
|
||||||
|
const {
|
||||||
|
refinerModel,
|
||||||
|
refinerAestheticScore,
|
||||||
|
positiveStylePrompt,
|
||||||
|
negativeStylePrompt,
|
||||||
|
refinerSteps,
|
||||||
|
refinerScheduler,
|
||||||
|
refinerCFGScale,
|
||||||
|
refinerStart,
|
||||||
|
} = state.sdxl;
|
||||||
|
|
||||||
|
if (!refinerModel) return;
|
||||||
|
|
||||||
|
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||||
|
| MetadataAccumulatorInvocation
|
||||||
|
| undefined;
|
||||||
|
|
||||||
|
if (metadataAccumulator) {
|
||||||
|
metadataAccumulator.refiner_model = refinerModel;
|
||||||
|
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
|
||||||
|
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||||
|
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||||
|
metadataAccumulator.refiner_start = refinerStart;
|
||||||
|
metadataAccumulator.refiner_steps = refinerSteps;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unplug SDXL Latents Generation To Latents To Image
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
|
||||||
|
);
|
||||||
|
|
||||||
|
graph.edges = graph.edges.filter(
|
||||||
|
(e) =>
|
||||||
|
!(
|
||||||
|
e.source.node_id === SDXL_MODEL_LOADER &&
|
||||||
|
['vae'].includes(e.source.field)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
// connect the VAE back to the i2l, which we just removed in the filter
|
||||||
|
// but only if we are doing l2l
|
||||||
|
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
|
||||||
|
type: 'sdxl_refiner_model_loader',
|
||||||
|
id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
model: refinerModel,
|
||||||
|
};
|
||||||
|
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
|
||||||
|
type: 'sdxl_refiner_compel_prompt',
|
||||||
|
id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
style: `${positivePrompt} ${positiveStylePrompt}`,
|
||||||
|
aesthetic_score: refinerAestheticScore,
|
||||||
|
};
|
||||||
|
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
|
||||||
|
type: 'sdxl_refiner_compel_prompt',
|
||||||
|
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
style: `${negativePrompt} ${negativeStylePrompt}`,
|
||||||
|
aesthetic_score: refinerAestheticScore,
|
||||||
|
};
|
||||||
|
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
|
||||||
|
type: 'l2l_sdxl',
|
||||||
|
id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
cfg_scale: refinerCFGScale,
|
||||||
|
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
|
||||||
|
scheduler: refinerScheduler,
|
||||||
|
denoising_start: refinerStart,
|
||||||
|
denoising_end: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.edges.push(
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
};
|
@ -46,6 +46,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
clipSkip,
|
clipSkip,
|
||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
shouldUseNoiseSettings,
|
shouldUseNoiseSettings,
|
||||||
|
vaePrecision,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
// TODO: add batch functionality
|
// TODO: add batch functionality
|
||||||
@ -113,6 +114,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
},
|
},
|
||||||
[LATENTS_TO_LATENTS]: {
|
[LATENTS_TO_LATENTS]: {
|
||||||
type: 'l2l',
|
type: 'l2l',
|
||||||
@ -129,6 +131,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
// image: {
|
// image: {
|
||||||
// image_name: initialImage.image_name,
|
// image_name: initialImage.image_name,
|
||||||
// },
|
// },
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -0,0 +1,369 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
|
import {
|
||||||
|
ImageResizeInvocation,
|
||||||
|
ImageToLatentsInvocation,
|
||||||
|
} from 'services/api/types';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||||
|
import {
|
||||||
|
IMAGE_TO_LATENTS,
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
METADATA_ACCUMULATOR,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
RESIZE,
|
||||||
|
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
SDXL_LATENTS_TO_LATENTS,
|
||||||
|
SDXL_MODEL_LOADER,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds the Image to Image tab graph.
|
||||||
|
*/
|
||||||
|
export const buildLinearSDXLImageToImageGraph = (
|
||||||
|
state: RootState
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const log = logger('nodes');
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
initialImage,
|
||||||
|
shouldFitToWidthHeight,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
clipSkip,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
shouldUseNoiseSettings,
|
||||||
|
vaePrecision,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
const {
|
||||||
|
positiveStylePrompt,
|
||||||
|
negativeStylePrompt,
|
||||||
|
shouldUseSDXLRefiner,
|
||||||
|
refinerStart,
|
||||||
|
sdxlImg2ImgDenoisingStrength: strength,
|
||||||
|
} = state.sdxl;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
if (!initialImage) {
|
||||||
|
log.error('No initial image found in state');
|
||||||
|
throw new Error('No initial image found in state');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!model) {
|
||||||
|
log.error('No model found in state');
|
||||||
|
throw new Error('No model found in state');
|
||||||
|
}
|
||||||
|
|
||||||
|
const use_cpu = shouldUseNoiseSettings
|
||||||
|
? shouldUseCpuNoise
|
||||||
|
: initialGenerationState.shouldUseCpuNoise;
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[SDXL_MODEL_LOADER]: {
|
||||||
|
type: 'sdxl_model_loader',
|
||||||
|
id: SDXL_MODEL_LOADER,
|
||||||
|
model,
|
||||||
|
},
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
style: positiveStylePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
style: negativeStylePrompt,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
use_cpu,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
|
},
|
||||||
|
[SDXL_LATENTS_TO_LATENTS]: {
|
||||||
|
type: 'l2l_sdxl',
|
||||||
|
id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
denoising_start: shouldUseSDXLRefiner
|
||||||
|
? Math.min(refinerStart, 1 - strength)
|
||||||
|
: 1 - strength,
|
||||||
|
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||||
|
},
|
||||||
|
[IMAGE_TO_LATENTS]: {
|
||||||
|
type: 'i2l',
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||||
|
// image: {
|
||||||
|
// image_name: initialImage.image_name,
|
||||||
|
// },
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// handle `fit`
|
||||||
|
if (
|
||||||
|
shouldFitToWidthHeight &&
|
||||||
|
(initialImage.width !== width || initialImage.height !== height)
|
||||||
|
) {
|
||||||
|
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||||
|
|
||||||
|
// Create a resize node, explicitly setting its image
|
||||||
|
const resizeNode: ImageResizeInvocation = {
|
||||||
|
id: RESIZE,
|
||||||
|
type: 'img_resize',
|
||||||
|
image: {
|
||||||
|
image_name: initialImage.imageName,
|
||||||
|
},
|
||||||
|
is_intermediate: true,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.nodes[RESIZE] = resizeNode;
|
||||||
|
|
||||||
|
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'image' },
|
||||||
|
destination: {
|
||||||
|
node_id: IMAGE_TO_LATENTS,
|
||||||
|
field: 'image',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: RESIZE, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||||
|
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
|
||||||
|
image_name: initialImage.imageName,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pass the image's dimensions to the `NOISE` node
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'width',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'height',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||||
|
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||||
|
id: METADATA_ACCUMULATOR,
|
||||||
|
type: 'metadata_accumulator',
|
||||||
|
generation_mode: 'sdxl_img2img',
|
||||||
|
cfg_scale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||||
|
negative_prompt: negativePrompt,
|
||||||
|
model,
|
||||||
|
seed: 0, // set in addDynamicPromptsToGraph
|
||||||
|
steps,
|
||||||
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
|
scheduler,
|
||||||
|
vae: undefined,
|
||||||
|
controlnets: [],
|
||||||
|
loras: [],
|
||||||
|
clip_skip: clipSkip,
|
||||||
|
strength: strength,
|
||||||
|
init_image: initialImage.imageName,
|
||||||
|
positive_style_prompt: positiveStylePrompt,
|
||||||
|
negative_style_prompt: negativeStylePrompt,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: METADATA_ACCUMULATOR,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add Refiner if enabled
|
||||||
|
if (shouldUseSDXLRefiner) {
|
||||||
|
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add dynamic prompts - also sets up core iteration and seed
|
||||||
|
addDynamicPromptsToGraph(state, graph);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -0,0 +1,251 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||||
|
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||||
|
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||||
|
import {
|
||||||
|
LATENTS_TO_IMAGE,
|
||||||
|
METADATA_ACCUMULATOR,
|
||||||
|
NEGATIVE_CONDITIONING,
|
||||||
|
NOISE,
|
||||||
|
POSITIVE_CONDITIONING,
|
||||||
|
SDXL_MODEL_LOADER,
|
||||||
|
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
|
SDXL_TEXT_TO_LATENTS,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
|
export const buildLinearSDXLTextToImageGraph = (
|
||||||
|
state: RootState
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const log = logger('nodes');
|
||||||
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
clipSkip,
|
||||||
|
shouldUseCpuNoise,
|
||||||
|
shouldUseNoiseSettings,
|
||||||
|
vaePrecision,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
|
const {
|
||||||
|
positiveStylePrompt,
|
||||||
|
negativeStylePrompt,
|
||||||
|
shouldUseSDXLRefiner,
|
||||||
|
refinerStart,
|
||||||
|
} = state.sdxl;
|
||||||
|
|
||||||
|
const use_cpu = shouldUseNoiseSettings
|
||||||
|
? shouldUseCpuNoise
|
||||||
|
: initialGenerationState.shouldUseCpuNoise;
|
||||||
|
|
||||||
|
if (!model) {
|
||||||
|
log.error('No model found in state');
|
||||||
|
throw new Error('No model found in state');
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||||
|
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||||
|
* ids.
|
||||||
|
*
|
||||||
|
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||||
|
* the `fit` param. These are added to the graph at the end.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||||
|
const graph: NonNullableGraph = {
|
||||||
|
id: SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||||
|
nodes: {
|
||||||
|
[SDXL_MODEL_LOADER]: {
|
||||||
|
type: 'sdxl_model_loader',
|
||||||
|
id: SDXL_MODEL_LOADER,
|
||||||
|
model,
|
||||||
|
},
|
||||||
|
[POSITIVE_CONDITIONING]: {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
style: positiveStylePrompt,
|
||||||
|
},
|
||||||
|
[NEGATIVE_CONDITIONING]: {
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
style: negativeStylePrompt,
|
||||||
|
},
|
||||||
|
[NOISE]: {
|
||||||
|
type: 'noise',
|
||||||
|
id: NOISE,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
use_cpu,
|
||||||
|
},
|
||||||
|
[SDXL_TEXT_TO_LATENTS]: {
|
||||||
|
type: 't2l_sdxl',
|
||||||
|
id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||||
|
},
|
||||||
|
[LATENTS_TO_IMAGE]: {
|
||||||
|
type: 'l2i',
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
edges: [
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
field: 'unet',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'vae',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_MODEL_LOADER,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'clip2',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: POSITIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NEGATIVE_CONDITIONING,
|
||||||
|
field: 'conditioning',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
source: {
|
||||||
|
node_id: SDXL_TEXT_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||||
|
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||||
|
id: METADATA_ACCUMULATOR,
|
||||||
|
type: 'metadata_accumulator',
|
||||||
|
generation_mode: 'sdxl_txt2img',
|
||||||
|
cfg_scale,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||||
|
negative_prompt: negativePrompt,
|
||||||
|
model,
|
||||||
|
seed: 0, // set in addDynamicPromptsToGraph
|
||||||
|
steps,
|
||||||
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
|
scheduler,
|
||||||
|
vae: undefined,
|
||||||
|
controlnets: [],
|
||||||
|
loras: [],
|
||||||
|
clip_skip: clipSkip,
|
||||||
|
positive_style_prompt: positiveStylePrompt,
|
||||||
|
negative_style_prompt: negativeStylePrompt,
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: METADATA_ACCUMULATOR,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'metadata',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add Refiner if enabled
|
||||||
|
if (shouldUseSDXLRefiner) {
|
||||||
|
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add dynamic prompts - also sets up core iteration and seed
|
||||||
|
addDynamicPromptsToGraph(state, graph);
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
};
|
@ -34,6 +34,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
clipSkip,
|
clipSkip,
|
||||||
shouldUseCpuNoise,
|
shouldUseCpuNoise,
|
||||||
shouldUseNoiseSettings,
|
shouldUseNoiseSettings,
|
||||||
|
vaePrecision,
|
||||||
} = state.generation;
|
} = state.generation;
|
||||||
|
|
||||||
const use_cpu = shouldUseNoiseSettings
|
const use_cpu = shouldUseNoiseSettings
|
||||||
@ -95,6 +96,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
[LATENTS_TO_IMAGE]: {
|
[LATENTS_TO_IMAGE]: {
|
||||||
type: 'l2i',
|
type: 'l2i',
|
||||||
id: LATENTS_TO_IMAGE,
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === 'fp32' ? true : false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
edges: [
|
edges: [
|
||||||
|
@ -23,8 +23,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
|||||||
export const REALESRGAN = 'esrgan';
|
export const REALESRGAN = 'esrgan';
|
||||||
export const DIVIDE = 'divide';
|
export const DIVIDE = 'divide';
|
||||||
export const SCALE = 'scale_image';
|
export const SCALE = 'scale_image';
|
||||||
|
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
|
||||||
|
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
|
||||||
|
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
|
||||||
|
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
|
||||||
|
export const SDXL_REFINER_POSITIVE_CONDITIONING =
|
||||||
|
'sdxl_refiner_positive_conditioning';
|
||||||
|
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
|
||||||
|
'sdxl_refiner_negative_conditioning';
|
||||||
|
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
|
||||||
|
|
||||||
// friendly graph ids
|
// friendly graph ids
|
||||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||||
|
export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
|
||||||
|
export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph';
|
||||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||||
export const INPAINT_GRAPH = 'inpaint_graph';
|
export const INPAINT_GRAPH = 'inpaint_graph';
|
||||||
|
@ -4,6 +4,7 @@ import { memo } from 'react';
|
|||||||
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||||
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||||
import ParamScheduler from './ParamScheduler';
|
import ParamScheduler from './ParamScheduler';
|
||||||
|
import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';
|
||||||
|
|
||||||
const ParamModelandVAEandScheduler = () => {
|
const ParamModelandVAEandScheduler = () => {
|
||||||
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
|
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
|
||||||
@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
|
|||||||
<Box w="full">
|
<Box w="full">
|
||||||
<ParamMainModelSelect />
|
<ParamMainModelSelect />
|
||||||
</Box>
|
</Box>
|
||||||
<Flex gap={3} w="full">
|
<Box w="full">
|
||||||
{isVaeEnabled && (
|
<ParamScheduler />
|
||||||
<Box w="full">
|
</Box>
|
||||||
<ParamVAEModelSelect />
|
{isVaeEnabled && (
|
||||||
</Box>
|
<Flex w="full" gap={3}>
|
||||||
)}
|
<ParamVAEModelSelect />
|
||||||
<Box w="full">
|
<ParamVAEPrecision />
|
||||||
<ParamScheduler />
|
</Flex>
|
||||||
</Box>
|
)}
|
||||||
</Flex>
|
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -13,7 +13,9 @@ import { modelSelected } from 'features/parameters/store/actions';
|
|||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
import { forEach } from 'lodash-es';
|
import { forEach } from 'lodash-es';
|
||||||
|
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||||
|
|
||||||
@ -29,8 +31,12 @@ const ParamMainModelSelect = () => {
|
|||||||
|
|
||||||
const { model } = useAppSelector(selector);
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||||
|
NON_REFINER_BASE_MODELS
|
||||||
|
);
|
||||||
|
|
||||||
|
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||||
|
|
||||||
const data = useMemo(() => {
|
const data = useMemo(() => {
|
||||||
if (!mainModels) {
|
if (!mainModels) {
|
||||||
@ -40,7 +46,10 @@ const ParamMainModelSelect = () => {
|
|||||||
const data: SelectItem[] = [];
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
forEach(mainModels.entities, (model, id) => {
|
forEach(mainModels.entities, (model, id) => {
|
||||||
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
|
if (
|
||||||
|
!model ||
|
||||||
|
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
|
||||||
|
) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +61,7 @@ const ParamMainModelSelect = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}, [mainModels]);
|
}, [mainModels, activeTabName]);
|
||||||
|
|
||||||
// grab the full model entity from the RTK Query cache
|
// grab the full model entity from the RTK Query cache
|
||||||
// TODO: maybe we should just store the full model entity in state?
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
@ -88,7 +97,7 @@ const ParamMainModelSelect = () => {
|
|||||||
data={[]}
|
data={[]}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<Flex w="100%" alignItems="center" gap={2}>
|
<Flex w="100%" alignItems="center" gap={3}>
|
||||||
<IAIMantineSearchableSelect
|
<IAIMantineSearchableSelect
|
||||||
tooltip={selectedModel?.description}
|
tooltip={selectedModel?.description}
|
||||||
label={t('modelManager.model')}
|
label={t('modelManager.model')}
|
||||||
|
@ -32,11 +32,6 @@ export default function ParamSeed() {
|
|||||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||||
onChange={handleChangeSeed}
|
onChange={handleChangeSeed}
|
||||||
value={seed}
|
value={seed}
|
||||||
formControlProps={{
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
|
|
||||||
}}
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';
|
|||||||
|
|
||||||
const ParamSeedFull = () => {
|
const ParamSeedFull = () => {
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ gap: 4, alignItems: 'center' }}>
|
<Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
|
||||||
<ParamSeed />
|
<ParamSeed />
|
||||||
<ParamSeedShuffle />
|
<ParamSeedShuffle />
|
||||||
<ParamSeedRandomize />
|
<ParamSeedRandomize />
|
||||||
|
@ -0,0 +1,46 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
|
||||||
|
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ generation }) => {
|
||||||
|
const { vaePrecision } = generation;
|
||||||
|
return { vaePrecision };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const DATA = ['fp16', 'fp32'];
|
||||||
|
|
||||||
|
const ParamVAEModelSelect = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { vaePrecision } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(vaePrecisionChanged(v as PrecisionParam));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
label="VAE Precision"
|
||||||
|
value={vaePrecision}
|
||||||
|
data={DATA}
|
||||||
|
onChange={handleChange}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamVAEModelSelect);
|
@ -11,6 +11,7 @@ import {
|
|||||||
MainModelParam,
|
MainModelParam,
|
||||||
NegativePromptParam,
|
NegativePromptParam,
|
||||||
PositivePromptParam,
|
PositivePromptParam,
|
||||||
|
PrecisionParam,
|
||||||
SchedulerParam,
|
SchedulerParam,
|
||||||
SeedParam,
|
SeedParam,
|
||||||
StepsParam,
|
StepsParam,
|
||||||
@ -51,6 +52,7 @@ export interface GenerationState {
|
|||||||
verticalSymmetrySteps: number;
|
verticalSymmetrySteps: number;
|
||||||
model: MainModelField | null;
|
model: MainModelField | null;
|
||||||
vae: VaeModelParam | null;
|
vae: VaeModelParam | null;
|
||||||
|
vaePrecision: PrecisionParam;
|
||||||
seamlessXAxis: boolean;
|
seamlessXAxis: boolean;
|
||||||
seamlessYAxis: boolean;
|
seamlessYAxis: boolean;
|
||||||
clipSkip: number;
|
clipSkip: number;
|
||||||
@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
verticalSymmetrySteps: 0,
|
verticalSymmetrySteps: 0,
|
||||||
model: null,
|
model: null,
|
||||||
vae: null,
|
vae: null,
|
||||||
|
vaePrecision: 'fp32',
|
||||||
seamlessXAxis: false,
|
seamlessXAxis: false,
|
||||||
seamlessYAxis: false,
|
seamlessYAxis: false,
|
||||||
clipSkip: 0,
|
clipSkip: 0,
|
||||||
@ -241,6 +244,9 @@ export const generationSlice = createSlice({
|
|||||||
// null is a valid VAE!
|
// null is a valid VAE!
|
||||||
state.vae = action.payload;
|
state.vae = action.payload;
|
||||||
},
|
},
|
||||||
|
vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
|
||||||
|
state.vaePrecision = action.payload;
|
||||||
|
},
|
||||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||||
state.clipSkip = action.payload;
|
state.clipSkip = action.payload;
|
||||||
},
|
},
|
||||||
@ -327,6 +333,7 @@ export const {
|
|||||||
shouldUseCpuNoiseChanged,
|
shouldUseCpuNoiseChanged,
|
||||||
setShouldShowAdvancedOptions,
|
setShouldShowAdvancedOptions,
|
||||||
setAspectRatio,
|
setAspectRatio,
|
||||||
|
vaePrecisionChanged,
|
||||||
} = generationSlice.actions;
|
} = generationSlice.actions;
|
||||||
|
|
||||||
export default generationSlice.reducer;
|
export default generationSlice.reducer;
|
||||||
|
@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
|
|||||||
val: unknown
|
val: unknown
|
||||||
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
|
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zod schema for SDXL positive style prompt parameter
|
||||||
|
*/
|
||||||
|
export const zPositiveStylePromptSDXL = z.string();
|
||||||
|
/**
|
||||||
|
* Type alias for SDXL positive style prompt parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type PositiveStylePromptSDXLParam = z.infer<
|
||||||
|
typeof zPositiveStylePromptSDXL
|
||||||
|
>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a SDXL positive style prompt parameter
|
||||||
|
*/
|
||||||
|
export const isValidSDXLPositiveStylePrompt = (
|
||||||
|
val: unknown
|
||||||
|
): val is PositiveStylePromptSDXLParam =>
|
||||||
|
zPositiveStylePromptSDXL.safeParse(val).success;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zod schema for SDXL negative style prompt parameter
|
||||||
|
*/
|
||||||
|
export const zNegativeStylePromptSDXL = z.string();
|
||||||
|
/**
|
||||||
|
* Type alias for SDXL negative style prompt parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type NegativeStylePromptSDXLParam = z.infer<
|
||||||
|
typeof zNegativeStylePromptSDXL
|
||||||
|
>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a SDXL negative style prompt parameter
|
||||||
|
*/
|
||||||
|
export const isValidSDXLNegativeStylePrompt = (
|
||||||
|
val: unknown
|
||||||
|
): val is NegativeStylePromptSDXLParam =>
|
||||||
|
zNegativeStylePromptSDXL.safeParse(val).success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Zod schema for steps parameter
|
* Zod schema for steps parameter
|
||||||
*/
|
*/
|
||||||
@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
|||||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||||
zStrength.safeParse(val).success;
|
zStrength.safeParse(val).success;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Zod schema for a precision parameter
|
||||||
|
*/
|
||||||
|
export const zPrecision = z.enum(['fp16', 'fp32']);
|
||||||
|
/**
|
||||||
|
* Type alias for precision parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type PrecisionParam = z.infer<typeof zPrecision>;
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a precision parameter
|
||||||
|
*/
|
||||||
|
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
|
||||||
|
zPrecision.safeParse(val).success;
|
||||||
|
|
||||||
// /**
|
// /**
|
||||||
// * Zod schema for BaseModelType
|
// * Zod schema for BaseModelType
|
||||||
// */
|
// */
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl }) => {
|
||||||
|
const { sdxlImg2ImgDenoisingStrength } = sdxl;
|
||||||
|
|
||||||
|
return {
|
||||||
|
sdxlImg2ImgDenoisingStrength,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLImg2ImgDenoisingStrength = () => {
|
||||||
|
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
label={`${t('parameters.denoisingStrength')}`}
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={sdxlImg2ImgDenoisingStrength}
|
||||||
|
isInteger={false}
|
||||||
|
withInput
|
||||||
|
withSliderMarks
|
||||||
|
withReset
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLImg2ImgDenoisingStrength);
|
@ -0,0 +1,149 @@
|
|||||||
|
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||||
|
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||||
|
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||||
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { flushSync } from 'react-dom';
|
||||||
|
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
|
||||||
|
|
||||||
|
const promptInputSelector = createSelector(
|
||||||
|
[stateSelector, activeTabNameSelector],
|
||||||
|
({ sdxl }, activeTabName) => {
|
||||||
|
return {
|
||||||
|
prompt: sdxl.negativeStylePrompt,
|
||||||
|
activeTabName,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prompt input text area.
|
||||||
|
*/
|
||||||
|
const ParamSDXLNegativeStyleConditioning = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||||
|
|
||||||
|
const handleChangePrompt = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
dispatch(setNegativeStylePromptSDXL(e.target.value));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSelectEmbedding = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
if (!promptRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is where we insert the TI trigger
|
||||||
|
const caret = promptRef.current.selectionStart;
|
||||||
|
|
||||||
|
if (caret === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newPrompt = prompt.slice(0, caret);
|
||||||
|
|
||||||
|
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||||
|
newPrompt += '<';
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrompt += `${v}>`;
|
||||||
|
|
||||||
|
// we insert the cursor after the `>`
|
||||||
|
const finalCaretPos = newPrompt.length;
|
||||||
|
|
||||||
|
newPrompt += prompt.slice(caret);
|
||||||
|
|
||||||
|
// must flush dom updates else selection gets reset
|
||||||
|
flushSync(() => {
|
||||||
|
dispatch(setNegativeStylePromptSDXL(newPrompt));
|
||||||
|
});
|
||||||
|
|
||||||
|
// set the caret position to just after the TI trigger
|
||||||
|
promptRef.current.selectionStart = finalCaretPos;
|
||||||
|
promptRef.current.selectionEnd = finalCaretPos;
|
||||||
|
onClose();
|
||||||
|
},
|
||||||
|
[dispatch, onClose, prompt]
|
||||||
|
);
|
||||||
|
|
||||||
|
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||||
|
|
||||||
|
const handleKeyDown = useCallback(
|
||||||
|
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||||
|
e.preventDefault();
|
||||||
|
dispatch(clampSymmetrySteps());
|
||||||
|
dispatch(userInvoked(activeTabName));
|
||||||
|
}
|
||||||
|
if (isEmbeddingEnabled && e.key === '<') {
|
||||||
|
onOpen();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||||
|
);
|
||||||
|
|
||||||
|
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||||
|
// const target = e.target as HTMLTextAreaElement;
|
||||||
|
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||||
|
// };
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box position="relative">
|
||||||
|
<FormControl>
|
||||||
|
<ParamEmbeddingPopover
|
||||||
|
isOpen={isOpen}
|
||||||
|
onClose={onClose}
|
||||||
|
onSelect={handleSelectEmbedding}
|
||||||
|
>
|
||||||
|
<IAITextarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
ref={promptRef}
|
||||||
|
value={prompt}
|
||||||
|
placeholder="Negative Style Prompt"
|
||||||
|
onChange={handleChangePrompt}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
resize="vertical"
|
||||||
|
fontSize="sm"
|
||||||
|
minH={16}
|
||||||
|
/>
|
||||||
|
</ParamEmbeddingPopover>
|
||||||
|
</FormControl>
|
||||||
|
{!isOpen && isEmbeddingEnabled && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<AddEmbeddingButton onClick={onOpen} />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamSDXLNegativeStyleConditioning;
|
@ -0,0 +1,148 @@
|
|||||||
|
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||||
|
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||||
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
|
import { userInvoked } from 'app/store/actions';
|
||||||
|
import IAITextarea from 'common/components/IAITextarea';
|
||||||
|
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||||
|
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||||
|
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||||
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import { isEqual } from 'lodash-es';
|
||||||
|
import { flushSync } from 'react-dom';
|
||||||
|
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
|
||||||
|
|
||||||
|
const promptInputSelector = createSelector(
|
||||||
|
[stateSelector, activeTabNameSelector],
|
||||||
|
({ sdxl }, activeTabName) => {
|
||||||
|
return {
|
||||||
|
prompt: sdxl.positiveStylePrompt,
|
||||||
|
activeTabName,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
{
|
||||||
|
memoizeOptions: {
|
||||||
|
resultEqualityCheck: isEqual,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prompt input text area.
|
||||||
|
*/
|
||||||
|
const ParamSDXLPositiveStyleConditioning = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||||
|
const isReady = useIsReadyToInvoke();
|
||||||
|
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||||
|
|
||||||
|
const handleChangePrompt = useCallback(
|
||||||
|
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
|
dispatch(setPositiveStylePromptSDXL(e.target.value));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleSelectEmbedding = useCallback(
|
||||||
|
(v: string) => {
|
||||||
|
if (!promptRef.current) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is where we insert the TI trigger
|
||||||
|
const caret = promptRef.current.selectionStart;
|
||||||
|
|
||||||
|
if (caret === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let newPrompt = prompt.slice(0, caret);
|
||||||
|
|
||||||
|
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||||
|
newPrompt += '<';
|
||||||
|
}
|
||||||
|
|
||||||
|
newPrompt += `${v}>`;
|
||||||
|
|
||||||
|
// we insert the cursor after the `>`
|
||||||
|
const finalCaretPos = newPrompt.length;
|
||||||
|
|
||||||
|
newPrompt += prompt.slice(caret);
|
||||||
|
|
||||||
|
// must flush dom updates else selection gets reset
|
||||||
|
flushSync(() => {
|
||||||
|
dispatch(setPositiveStylePromptSDXL(newPrompt));
|
||||||
|
});
|
||||||
|
|
||||||
|
// set the caret position to just after the TI trigger
|
||||||
|
promptRef.current.selectionStart = finalCaretPos;
|
||||||
|
promptRef.current.selectionEnd = finalCaretPos;
|
||||||
|
onClose();
|
||||||
|
},
|
||||||
|
[dispatch, onClose, prompt]
|
||||||
|
);
|
||||||
|
|
||||||
|
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||||
|
|
||||||
|
const handleKeyDown = useCallback(
|
||||||
|
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||||
|
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||||
|
e.preventDefault();
|
||||||
|
dispatch(clampSymmetrySteps());
|
||||||
|
dispatch(userInvoked(activeTabName));
|
||||||
|
}
|
||||||
|
if (isEmbeddingEnabled && e.key === '<') {
|
||||||
|
onOpen();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||||
|
);
|
||||||
|
|
||||||
|
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||||
|
// const target = e.target as HTMLTextAreaElement;
|
||||||
|
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||||
|
// };
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Box position="relative">
|
||||||
|
<FormControl>
|
||||||
|
<ParamEmbeddingPopover
|
||||||
|
isOpen={isOpen}
|
||||||
|
onClose={onClose}
|
||||||
|
onSelect={handleSelectEmbedding}
|
||||||
|
>
|
||||||
|
<IAITextarea
|
||||||
|
id="prompt"
|
||||||
|
name="prompt"
|
||||||
|
ref={promptRef}
|
||||||
|
value={prompt}
|
||||||
|
placeholder="Positive Style Prompt"
|
||||||
|
onChange={handleChangePrompt}
|
||||||
|
onKeyDown={handleKeyDown}
|
||||||
|
resize="vertical"
|
||||||
|
minH={16}
|
||||||
|
/>
|
||||||
|
</ParamEmbeddingPopover>
|
||||||
|
</FormControl>
|
||||||
|
{!isOpen && isEmbeddingEnabled && (
|
||||||
|
<Box
|
||||||
|
sx={{
|
||||||
|
position: 'absolute',
|
||||||
|
top: 0,
|
||||||
|
insetInlineEnd: 0,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<AddEmbeddingButton onClick={onOpen} />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamSDXLPositiveStyleConditioning;
|
@ -0,0 +1,48 @@
|
|||||||
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
|
||||||
|
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
|
||||||
|
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
|
||||||
|
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
|
||||||
|
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
|
||||||
|
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
|
||||||
|
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => {
|
||||||
|
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||||
|
const { shouldUseSliders } = state.ui;
|
||||||
|
return {
|
||||||
|
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
|
||||||
|
shouldUseSliders,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerCollapse = () => {
|
||||||
|
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAICollapse label="Refiner" activeLabel={activeLabel}>
|
||||||
|
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||||
|
<ParamUseSDXLRefiner />
|
||||||
|
<ParamSDXLRefinerModelSelect />
|
||||||
|
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
|
||||||
|
<ParamSDXLRefinerSteps />
|
||||||
|
<ParamSDXLRefinerCFGScale />
|
||||||
|
</Flex>
|
||||||
|
<ParamSDXLRefinerScheduler />
|
||||||
|
<ParamSDXLRefinerAestheticScore />
|
||||||
|
<ParamSDXLRefinerStart />
|
||||||
|
</Flex>
|
||||||
|
</IAICollapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default ParamSDXLRefinerCollapse;
|
@ -0,0 +1,78 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAICollapse from 'common/components/IAICollapse';
|
||||||
|
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||||
|
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||||
|
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
|
||||||
|
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
|
||||||
|
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||||
|
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||||
|
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||||
|
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||||
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { memo } from 'react';
|
||||||
|
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[uiSelector, generationSelector],
|
||||||
|
(ui, generation) => {
|
||||||
|
const { shouldUseSliders } = ui;
|
||||||
|
const { shouldRandomizeSeed } = generation;
|
||||||
|
|
||||||
|
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
|
||||||
|
|
||||||
|
return { shouldUseSliders, activeLabel };
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const SDXLImageToImageTabCoreParameters = () => {
|
||||||
|
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAICollapse
|
||||||
|
label={'General'}
|
||||||
|
activeLabel={activeLabel}
|
||||||
|
defaultIsOpen={true}
|
||||||
|
>
|
||||||
|
<Flex
|
||||||
|
sx={{
|
||||||
|
flexDirection: 'column',
|
||||||
|
gap: 3,
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{shouldUseSliders ? (
|
||||||
|
<>
|
||||||
|
<ParamIterations />
|
||||||
|
<ParamSteps />
|
||||||
|
<ParamCFGScale />
|
||||||
|
<ParamModelandVAEandScheduler />
|
||||||
|
<Box pt={2}>
|
||||||
|
<ParamSeedFull />
|
||||||
|
</Box>
|
||||||
|
<ParamSize />
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Flex gap={3}>
|
||||||
|
<ParamIterations />
|
||||||
|
<ParamSteps />
|
||||||
|
<ParamCFGScale />
|
||||||
|
</Flex>
|
||||||
|
<ParamModelandVAEandScheduler />
|
||||||
|
<Box pt={2}>
|
||||||
|
<ParamSeedFull />
|
||||||
|
</Box>
|
||||||
|
<ParamSize />
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<ParamSDXLImg2ImgDenoisingStrength />
|
||||||
|
<ImageToImageFit />
|
||||||
|
</Flex>
|
||||||
|
</IAICollapse>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(SDXLImageToImageTabCoreParameters);
|
@ -0,0 +1,28 @@
|
|||||||
|
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||||
|
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||||
|
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||||
|
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||||
|
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||||
|
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||||
|
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||||
|
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||||
|
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||||
|
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
|
||||||
|
|
||||||
|
const SDXLImageToImageTabParameters = () => {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<ParamPositiveConditioning />
|
||||||
|
<ParamSDXLPositiveStyleConditioning />
|
||||||
|
<ParamNegativeConditioning />
|
||||||
|
<ParamSDXLNegativeStyleConditioning />
|
||||||
|
<ProcessButtons />
|
||||||
|
<SDXLImageToImageTabCoreParameters />
|
||||||
|
<ParamSDXLRefinerCollapse />
|
||||||
|
<ParamDynamicPromptsCollapse />
|
||||||
|
<ParamNoiseCollapse />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDXLImageToImageTabParameters;
|
@ -0,0 +1,60 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl, hotkeys }) => {
|
||||||
|
const { refinerAestheticScore } = sdxl;
|
||||||
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
|
return {
|
||||||
|
refinerAestheticScore,
|
||||||
|
shift,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerAestheticScore = () => {
|
||||||
|
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setRefinerAestheticScore(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(
|
||||||
|
() => dispatch(setRefinerAestheticScore(6)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
label="Aesthetic Score"
|
||||||
|
step={shift ? 0.1 : 0.5}
|
||||||
|
min={1}
|
||||||
|
max={10}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={refinerAestheticScore}
|
||||||
|
sliderNumberInputProps={{ max: 10 }}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerAestheticScore);
|
@ -0,0 +1,75 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl, ui, hotkeys }) => {
|
||||||
|
const { refinerCFGScale } = sdxl;
|
||||||
|
const { shouldUseSliders } = ui;
|
||||||
|
const { shift } = hotkeys;
|
||||||
|
|
||||||
|
return {
|
||||||
|
refinerCFGScale,
|
||||||
|
shouldUseSliders,
|
||||||
|
shift,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerCFGScale = () => {
|
||||||
|
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setRefinerCFGScale(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(
|
||||||
|
() => dispatch(setRefinerCFGScale(7)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return shouldUseSliders ? (
|
||||||
|
<IAISlider
|
||||||
|
label={t('parameters.cfgScale')}
|
||||||
|
step={shift ? 0.1 : 0.5}
|
||||||
|
min={1}
|
||||||
|
max={20}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={refinerCFGScale}
|
||||||
|
sliderNumberInputProps={{ max: 200 }}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAINumberInput
|
||||||
|
label={t('parameters.cfgScale')}
|
||||||
|
step={0.5}
|
||||||
|
min={1}
|
||||||
|
max={200}
|
||||||
|
onChange={handleChange}
|
||||||
|
value={refinerCFGScale}
|
||||||
|
isInteger={false}
|
||||||
|
numberInputFieldProps={{ textAlign: 'center' }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerCFGScale);
|
@ -0,0 +1,111 @@
|
|||||||
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||||
|
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||||
|
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
(state) => ({ model: state.sdxl.refinerModel }),
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerModelSelect = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||||
|
|
||||||
|
const { model } = useAppSelector(selector);
|
||||||
|
|
||||||
|
const { data: refinerModels, isLoading } =
|
||||||
|
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!refinerModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(refinerModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [refinerModels]);
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
// TODO: maybe we should just store the full model entity in state?
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
refinerModels?.entities[
|
||||||
|
`${model?.base_model}/main/${model?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[refinerModels?.entities, model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleChangeModel = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newModel = modelIdToMainModelParam(v);
|
||||||
|
|
||||||
|
if (!newModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(refinerModelChanged(newModel));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return isLoading ? (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
label="Refiner Model"
|
||||||
|
placeholder="Loading..."
|
||||||
|
disabled={true}
|
||||||
|
data={[]}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Flex w="100%" alignItems="center" gap={2}>
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label="Refiner Model"
|
||||||
|
value={selectedModel?.id}
|
||||||
|
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||||
|
data={data}
|
||||||
|
error={data.length === 0}
|
||||||
|
disabled={data.length === 0}
|
||||||
|
onChange={handleChangeModel}
|
||||||
|
w="100%"
|
||||||
|
/>
|
||||||
|
{isSyncModelEnabled && (
|
||||||
|
<Box mt={7}>
|
||||||
|
<SyncModelsButton iconMode />
|
||||||
|
</Box>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerModelSelect);
|
@ -0,0 +1,65 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||||
|
import {
|
||||||
|
SCHEDULER_LABEL_MAP,
|
||||||
|
SchedulerParam,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { map } from 'lodash-es';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
stateSelector,
|
||||||
|
({ ui, sdxl }) => {
|
||||||
|
const { refinerScheduler } = sdxl;
|
||||||
|
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||||
|
|
||||||
|
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||||
|
value: name,
|
||||||
|
label: label,
|
||||||
|
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||||
|
? 'Favorites'
|
||||||
|
: undefined,
|
||||||
|
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
|
return {
|
||||||
|
refinerScheduler,
|
||||||
|
data,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerScheduler = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { refinerScheduler, data } = useAppSelector(selector);
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(setRefinerScheduler(v as SchedulerParam));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSearchableSelect
|
||||||
|
w="100%"
|
||||||
|
label={t('parameters.scheduler')}
|
||||||
|
value={refinerScheduler}
|
||||||
|
data={data}
|
||||||
|
onChange={handleChange}
|
||||||
|
disabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerScheduler);
|
@ -0,0 +1,53 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl }) => {
|
||||||
|
const { refinerStart } = sdxl;
|
||||||
|
return {
|
||||||
|
refinerStart,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerStart = () => {
|
||||||
|
const { refinerStart } = useAppSelector(selector);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => dispatch(setRefinerStart(v)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleReset = useCallback(
|
||||||
|
() => dispatch(setRefinerStart(0.7)),
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISlider
|
||||||
|
label="Refiner Start"
|
||||||
|
step={0.01}
|
||||||
|
min={0}
|
||||||
|
max={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={refinerStart}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
isInteger={false}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerStart);
|
@ -0,0 +1,72 @@
|
|||||||
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
|
import { stateSelector } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
|
import IAINumberInput from 'common/components/IAINumberInput';
|
||||||
|
import IAISlider from 'common/components/IAISlider';
|
||||||
|
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { memo, useCallback } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
const selector = createSelector(
|
||||||
|
[stateSelector],
|
||||||
|
({ sdxl, ui }) => {
|
||||||
|
const { refinerSteps } = sdxl;
|
||||||
|
const { shouldUseSliders } = ui;
|
||||||
|
|
||||||
|
return {
|
||||||
|
refinerSteps,
|
||||||
|
shouldUseSliders,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
defaultSelectorOptions
|
||||||
|
);
|
||||||
|
|
||||||
|
const ParamSDXLRefinerSteps = () => {
|
||||||
|
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const handleChange = useCallback(
|
||||||
|
(v: number) => {
|
||||||
|
dispatch(setRefinerSteps(v));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
const handleReset = useCallback(() => {
|
||||||
|
dispatch(setRefinerSteps(20));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
return shouldUseSliders ? (
|
||||||
|
<IAISlider
|
||||||
|
label={t('parameters.steps')}
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
handleReset={handleReset}
|
||||||
|
value={refinerSteps}
|
||||||
|
withInput
|
||||||
|
withReset
|
||||||
|
withSliderMarks
|
||||||
|
sliderNumberInputProps={{ max: 500 }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IAINumberInput
|
||||||
|
label={t('parameters.steps')}
|
||||||
|
min={1}
|
||||||
|
max={500}
|
||||||
|
step={1}
|
||||||
|
onChange={handleChange}
|
||||||
|
value={refinerSteps}
|
||||||
|
numberInputFieldProps={{ textAlign: 'center' }}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamSDXLRefinerSteps);
|
@ -0,0 +1,28 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import IAISwitch from 'common/components/IAISwitch';
|
||||||
|
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { ChangeEvent } from 'react';
|
||||||
|
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||||
|
|
||||||
|
export default function ParamUseSDXLRefiner() {
|
||||||
|
const shouldUseSDXLRefiner = useAppSelector(
|
||||||
|
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||||
|
);
|
||||||
|
const isRefinerAvailable = useIsRefinerAvailable();
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
|
||||||
|
dispatch(setShouldUseSDXLRefiner(e.target.checked));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAISwitch
|
||||||
|
label="Use Refiner"
|
||||||
|
isChecked={shouldUseSDXLRefiner}
|
||||||
|
onChange={handleUseSDXLRefinerChange}
|
||||||
|
isDisabled={!isRefinerAvailable}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,27 @@
|
|||||||
|
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||||
|
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||||
|
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||||
|
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||||
|
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||||
|
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
|
||||||
|
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||||
|
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||||
|
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||||
|
|
||||||
|
const SDXLTextToImageTabParameters = () => {
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<ParamPositiveConditioning />
|
||||||
|
<ParamSDXLPositiveStyleConditioning />
|
||||||
|
<ParamNegativeConditioning />
|
||||||
|
<ParamSDXLNegativeStyleConditioning />
|
||||||
|
<ProcessButtons />
|
||||||
|
<TextToImageTabCoreParameters />
|
||||||
|
<ParamSDXLRefinerCollapse />
|
||||||
|
<ParamDynamicPromptsCollapse />
|
||||||
|
<ParamNoiseCollapse />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default SDXLTextToImageTabParameters;
|
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||||
|
import {
|
||||||
|
MainModelParam,
|
||||||
|
NegativeStylePromptSDXLParam,
|
||||||
|
PositiveStylePromptSDXLParam,
|
||||||
|
SchedulerParam,
|
||||||
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { MainModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
type SDXLInitialState = {
|
||||||
|
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||||
|
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||||
|
shouldUseSDXLRefiner: boolean;
|
||||||
|
sdxlImg2ImgDenoisingStrength: number;
|
||||||
|
refinerModel: MainModelField | null;
|
||||||
|
refinerSteps: number;
|
||||||
|
refinerCFGScale: number;
|
||||||
|
refinerScheduler: SchedulerParam;
|
||||||
|
refinerAestheticScore: number;
|
||||||
|
refinerStart: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
const sdxlInitialState: SDXLInitialState = {
|
||||||
|
positiveStylePrompt: '',
|
||||||
|
negativeStylePrompt: '',
|
||||||
|
shouldUseSDXLRefiner: false,
|
||||||
|
sdxlImg2ImgDenoisingStrength: 0.7,
|
||||||
|
refinerModel: null,
|
||||||
|
refinerSteps: 20,
|
||||||
|
refinerCFGScale: 7.5,
|
||||||
|
refinerScheduler: 'euler',
|
||||||
|
refinerAestheticScore: 6,
|
||||||
|
refinerStart: 0.7,
|
||||||
|
};
|
||||||
|
|
||||||
|
const sdxlSlice = createSlice({
|
||||||
|
name: 'sdxl',
|
||||||
|
initialState: sdxlInitialState,
|
||||||
|
reducers: {
|
||||||
|
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||||
|
state.positiveStylePrompt = action.payload;
|
||||||
|
},
|
||||||
|
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||||
|
state.negativeStylePrompt = action.payload;
|
||||||
|
},
|
||||||
|
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||||
|
state.shouldUseSDXLRefiner = action.payload;
|
||||||
|
},
|
||||||
|
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
|
||||||
|
state.sdxlImg2ImgDenoisingStrength = action.payload;
|
||||||
|
},
|
||||||
|
refinerModelChanged: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<MainModelParam | null>
|
||||||
|
) => {
|
||||||
|
state.refinerModel = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerSteps: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerSteps = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerCFGScale = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||||
|
state.refinerScheduler = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerAestheticScore = action.payload;
|
||||||
|
},
|
||||||
|
setRefinerStart: (state, action: PayloadAction<number>) => {
|
||||||
|
state.refinerStart = action.payload;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
export const {
|
||||||
|
setPositiveStylePromptSDXL,
|
||||||
|
setNegativeStylePromptSDXL,
|
||||||
|
setShouldUseSDXLRefiner,
|
||||||
|
setSDXLImg2ImgDenoisingStrength,
|
||||||
|
refinerModelChanged,
|
||||||
|
setRefinerSteps,
|
||||||
|
setRefinerCFGScale,
|
||||||
|
setRefinerScheduler,
|
||||||
|
setRefinerAestheticScore,
|
||||||
|
setRefinerStart,
|
||||||
|
} = sdxlSlice.actions;
|
||||||
|
|
||||||
|
export default sdxlSlice.reducer;
|
@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||||
import { configSelector } from 'features/system/store/configSelectors';
|
import { configSelector } from 'features/system/store/configSelectors';
|
||||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||||
import { ResourceKey } from 'i18next';
|
import { ResourceKey } from 'i18next';
|
||||||
import { isEqual } from 'lodash-es';
|
import { isEqual } from 'lodash-es';
|
||||||
@ -172,13 +172,22 @@ const InvokeTabs = () => {
|
|||||||
const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
|
const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
|
||||||
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
|
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
|
||||||
|
|
||||||
|
const handleTabChange = useCallback(
|
||||||
|
(index: number) => {
|
||||||
|
const activeTabName = tabMap[index];
|
||||||
|
if (!activeTabName) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(setActiveTab(activeTabName));
|
||||||
|
},
|
||||||
|
[dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Tabs
|
<Tabs
|
||||||
defaultIndex={activeTab}
|
defaultIndex={activeTab}
|
||||||
index={activeTab}
|
index={activeTab}
|
||||||
onChange={(index: number) => {
|
onChange={handleTabChange}
|
||||||
dispatch(setActiveTab(index));
|
|
||||||
}}
|
|
||||||
sx={{
|
sx={{
|
||||||
flexGrow: 1,
|
flexGrow: 1,
|
||||||
gap: 4,
|
gap: 4,
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import { Box, Flex } from '@chakra-ui/react';
|
import { Box, Flex } from '@chakra-ui/react';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||||
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
|
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
|
||||||
|
import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
|
||||||
import { memo, useCallback, useRef } from 'react';
|
import { memo, useCallback, useRef } from 'react';
|
||||||
import {
|
import {
|
||||||
ImperativePanelGroupHandle,
|
ImperativePanelGroupHandle,
|
||||||
@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
|
|||||||
const ImageToImageTab = () => {
|
const ImageToImageTab = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||||
|
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||||
|
|
||||||
const handleDoubleClickHandle = useCallback(() => {
|
const handleDoubleClickHandle = useCallback(() => {
|
||||||
if (!panelGroupRef.current) {
|
if (!panelGroupRef.current) {
|
||||||
@ -28,7 +31,11 @@ const ImageToImageTab = () => {
|
|||||||
return (
|
return (
|
||||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||||
<ParametersPinnedWrapper>
|
<ParametersPinnedWrapper>
|
||||||
<ImageToImageTabParameters />
|
{model && model.base_model === 'sdxl' ? (
|
||||||
|
<SDXLImageToImageTabParameters />
|
||||||
|
) : (
|
||||||
|
<ImageToImageTabParameters />
|
||||||
|
)}
|
||||||
</ParametersPinnedWrapper>
|
</ParametersPinnedWrapper>
|
||||||
<Box sx={{ w: 'full', h: 'full' }}>
|
<Box sx={{ w: 'full', h: 'full' }}>
|
||||||
<PanelGroup
|
<PanelGroup
|
||||||
|
@ -16,6 +16,7 @@ import {
|
|||||||
useImportMainModelsMutation,
|
useImportMainModelsMutation,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
export default function FoundModelsList() {
|
export default function FoundModelsList() {
|
||||||
const searchFolder = useAppSelector(
|
const searchFolder = useAppSelector(
|
||||||
@ -24,7 +25,7 @@ export default function FoundModelsList() {
|
|||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
const [nameFilter, setNameFilter] = useState<string>('');
|
||||||
|
|
||||||
// Get paths of models that are already installed
|
// Get paths of models that are already installed
|
||||||
const { data: installedModels } = useGetMainModelsQuery();
|
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||||
|
|
||||||
// Get all model paths from a given directory
|
// Get all model paths from a given directory
|
||||||
const { foundModels, alreadyInstalled, filteredModels } =
|
const { foundModels, alreadyInstalled, filteredModels } =
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
|||||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||||
import IAISlider from 'common/components/IAISlider';
|
import IAISlider from 'common/components/IAISlider';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { pickBy } from 'lodash-es';
|
import { pickBy } from 'lodash-es';
|
||||||
import { useMemo, useState } from 'react';
|
import { useMemo, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
import {
|
import {
|
||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useMergeMainModelsMutation,
|
useMergeMainModelsMutation,
|
||||||
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery();
|
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||||
|
|
||||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||||
|
|
||||||
|
@ -8,10 +8,11 @@ import {
|
|||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
import ModelList from './ModelManagerPanel/ModelList';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
export default function ModelManagerPanel() {
|
export default function ModelManagerPanel() {
|
||||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||||
const { model } = useGetMainModelsQuery(undefined, {
|
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||||
}),
|
}),
|
||||||
|
@ -11,6 +11,7 @@ import {
|
|||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
} from 'services/api/endpoints/models';
|
} from 'services/api/endpoints/models';
|
||||||
import ModelListItem from './ModelListItem';
|
import ModelListItem from './ModelListItem';
|
||||||
|
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||||
|
|
||||||
type ModelListProps = {
|
type ModelListProps = {
|
||||||
selectedModelId: string | undefined;
|
selectedModelId: string | undefined;
|
||||||
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
|
|||||||
const [modelFormatFilter, setModelFormatFilter] =
|
const [modelFormatFilter, setModelFormatFilter] =
|
||||||
useState<ModelFormat>('images');
|
useState<ModelFormat>('images');
|
||||||
|
|
||||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||||
selectFromResult: ({ data }) => ({
|
selectFromResult: ({ data }) => ({
|
||||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||||
}),
|
}),
|
||||||
|
@ -1,14 +1,22 @@
|
|||||||
import { Flex } from '@chakra-ui/react';
|
import { Flex } from '@chakra-ui/react';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
|
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
|
||||||
import TextToImageTabMain from './TextToImageTabMain';
|
import TextToImageTabMain from './TextToImageTabMain';
|
||||||
import TextToImageTabParameters from './TextToImageTabParameters';
|
import TextToImageTabParameters from './TextToImageTabParameters';
|
||||||
|
|
||||||
const TextToImageTab = () => {
|
const TextToImageTab = () => {
|
||||||
|
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||||
<ParametersPinnedWrapper>
|
<ParametersPinnedWrapper>
|
||||||
<TextToImageTabParameters />
|
{model && model.base_model === 'sdxl' ? (
|
||||||
|
<TextToImageSDXLTabParameters />
|
||||||
|
) : (
|
||||||
|
<TextToImageTabParameters />
|
||||||
|
)}
|
||||||
</ParametersPinnedWrapper>
|
</ParametersPinnedWrapper>
|
||||||
<TextToImageTabMain />
|
<TextToImageTabMain />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -26,7 +26,7 @@ export const uiSlice = createSlice({
|
|||||||
name: 'ui',
|
name: 'ui',
|
||||||
initialState: initialUIState,
|
initialState: initialUIState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
|
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||||
setActiveTabReducer(state, action.payload);
|
setActiveTabReducer(state, action.payload);
|
||||||
},
|
},
|
||||||
setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => {
|
setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => {
|
||||||
|
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import { BaseModelType } from './types';
|
||||||
|
|
||||||
|
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||||
|
'sd-1',
|
||||||
|
'sd-2',
|
||||||
|
'sdxl',
|
||||||
|
'sdxl-refiner',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||||
|
'sd-1',
|
||||||
|
'sd-2',
|
||||||
|
'sdxl',
|
||||||
|
];
|
||||||
|
|
||||||
|
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
@ -144,8 +144,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
|||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
getMainModels: build.query<
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
EntityState<MainModelConfigEntity>,
|
||||||
|
BaseModelType[]
|
||||||
|
>({
|
||||||
|
query: (base_models) => {
|
||||||
|
const params = {
|
||||||
|
model_type: 'main',
|
||||||
|
base_models,
|
||||||
|
};
|
||||||
|
|
||||||
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
|
return `models/?${query}`;
|
||||||
|
},
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
const tags: ApiFullTagDescription[] = [
|
const tags: ApiFullTagDescription[] = [
|
||||||
{ type: 'MainModel', id: LIST_TAG },
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
@ -187,7 +198,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
importMainModels: build.mutation<
|
importMainModels: build.mutation<
|
||||||
ImportMainModelResponse,
|
ImportMainModelResponse,
|
||||||
@ -200,7 +214,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||||
query: ({ body }) => {
|
query: ({ body }) => {
|
||||||
@ -210,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
deleteMainModels: build.mutation<
|
deleteMainModels: build.mutation<
|
||||||
DeleteMainModelResponse,
|
DeleteMainModelResponse,
|
||||||
@ -222,7 +242,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<
|
convertMainModels: build.mutation<
|
||||||
ConvertMainModelResponse,
|
ConvertMainModelResponse,
|
||||||
@ -235,7 +258,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
params: params,
|
params: params,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||||
query: ({ base_model, body }) => {
|
query: ({ base_model, body }) => {
|
||||||
@ -245,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: body,
|
body: body,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||||
query: () => {
|
query: () => {
|
||||||
@ -254,7 +283,10 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
method: 'POST',
|
method: 'POST',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
invalidatesTags: [
|
||||||
|
{ type: 'MainModel', id: LIST_TAG },
|
||||||
|
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
|
@ -0,0 +1,12 @@
|
|||||||
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
|
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
export const useIsRefinerAvailable = () => {
|
||||||
|
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||||
|
selectFromResult: ({ data }) => ({
|
||||||
|
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
return isRefinerAvailable;
|
||||||
|
};
|
@ -1014,6 +1014,11 @@ export type components = {
|
|||||||
* @description The LoRAs used for inference
|
* @description The LoRAs used for inference
|
||||||
*/
|
*/
|
||||||
loras: (components["schemas"]["LoRAMetadataField"])[];
|
loras: (components["schemas"]["LoRAMetadataField"])[];
|
||||||
|
/**
|
||||||
|
* Vae
|
||||||
|
* @description The VAE used for decoding, if the main model's default was not used
|
||||||
|
*/
|
||||||
|
vae?: components["schemas"]["VAEModelField"];
|
||||||
/**
|
/**
|
||||||
* Strength
|
* Strength
|
||||||
* @description The strength used for latents-to-latents
|
* @description The strength used for latents-to-latents
|
||||||
@ -1025,10 +1030,45 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
init_image?: string;
|
init_image?: string;
|
||||||
/**
|
/**
|
||||||
* Vae
|
* Positive Style Prompt
|
||||||
* @description The VAE used for decoding, if the main model's default was not used
|
* @description The positive style prompt parameter
|
||||||
*/
|
*/
|
||||||
vae?: components["schemas"]["VAEModelField"];
|
positive_style_prompt?: string;
|
||||||
|
/**
|
||||||
|
* Negative Style Prompt
|
||||||
|
* @description The negative style prompt parameter
|
||||||
|
*/
|
||||||
|
negative_style_prompt?: string;
|
||||||
|
/**
|
||||||
|
* Refiner Model
|
||||||
|
* @description The SDXL Refiner model used
|
||||||
|
*/
|
||||||
|
refiner_model?: components["schemas"]["MainModelField"];
|
||||||
|
/**
|
||||||
|
* Refiner Cfg Scale
|
||||||
|
* @description The classifier-free guidance scale parameter used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_cfg_scale?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Steps
|
||||||
|
* @description The number of steps used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_steps?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Scheduler
|
||||||
|
* @description The scheduler used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_scheduler?: string;
|
||||||
|
/**
|
||||||
|
* Refiner Aesthetic Store
|
||||||
|
* @description The aesthetic score used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_aesthetic_store?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Start
|
||||||
|
* @description The start value used for refiner denoising
|
||||||
|
*/
|
||||||
|
refiner_start?: number;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* CvInpaintInvocation
|
* CvInpaintInvocation
|
||||||
@ -3268,6 +3308,46 @@ export type components = {
|
|||||||
* @description The VAE used for decoding, if the main model's default was not used
|
* @description The VAE used for decoding, if the main model's default was not used
|
||||||
*/
|
*/
|
||||||
vae?: components["schemas"]["VAEModelField"];
|
vae?: components["schemas"]["VAEModelField"];
|
||||||
|
/**
|
||||||
|
* Positive Style Prompt
|
||||||
|
* @description The positive style prompt parameter
|
||||||
|
*/
|
||||||
|
positive_style_prompt?: string;
|
||||||
|
/**
|
||||||
|
* Negative Style Prompt
|
||||||
|
* @description The negative style prompt parameter
|
||||||
|
*/
|
||||||
|
negative_style_prompt?: string;
|
||||||
|
/**
|
||||||
|
* Refiner Model
|
||||||
|
* @description The SDXL Refiner model used
|
||||||
|
*/
|
||||||
|
refiner_model?: components["schemas"]["MainModelField"];
|
||||||
|
/**
|
||||||
|
* Refiner Cfg Scale
|
||||||
|
* @description The classifier-free guidance scale parameter used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_cfg_scale?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Steps
|
||||||
|
* @description The number of steps used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_steps?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Scheduler
|
||||||
|
* @description The scheduler used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_scheduler?: string;
|
||||||
|
/**
|
||||||
|
* Refiner Aesthetic Store
|
||||||
|
* @description The aesthetic score used for the refiner
|
||||||
|
*/
|
||||||
|
refiner_aesthetic_store?: number;
|
||||||
|
/**
|
||||||
|
* Refiner Start
|
||||||
|
* @description The start value used for refiner denoising
|
||||||
|
*/
|
||||||
|
refiner_start?: number;
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* MetadataAccumulatorOutput
|
* MetadataAccumulatorOutput
|
||||||
@ -5355,6 +5435,12 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* StableDiffusion1ModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
/**
|
||||||
* StableDiffusion2ModelFormat
|
* StableDiffusion2ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -5367,12 +5453,6 @@ export type components = {
|
|||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
|
||||||
* StableDiffusion1ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user