feat: Add SDXL To Linear UI (#3973)

## What type of PR is this? (check all applicable)

- [x] Feature


## Have you discussed this change with the InvokeAI team?
- [x] Yes

## Description

This PR adds support for SDXL Models in the Linear UI

### DONE

- SDXL Base Text To Image Support
- SDXL Base Image To Image Support
- SDXL Refiner Support
- SDXL Relevant UI


## [optional] Are there any post deployment tasks we need to perform?

Double check to ensure nothing major changed with 1.0 -- In any case
those changes would be backend related mostly. If Refiner is scrapped
for 1.0 models, then we simply disable the Refiner Graph.
This commit is contained in:
blessedcoolant 2023-07-26 17:05:39 +12:00 committed by GitHub
commit 531bc40d3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 2671 additions and 110 deletions

View File

@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
def _lora_loader(): def _lora_loader():
for lora in self.clip.loras: for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled): def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
def run_clip_compel(self, context, clip_field, prompt, get_pooled): def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model( tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(), **clip_field.tokenizer.dict(), context=context,
) )
text_encoder_info = context.services.model_manager.get_model( text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(), **clip_field.text_encoder.dict(), context=context,
) )
def _lora_loader(): def _lora_loader():
for lora in clip_field.loras: for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model( lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"})) **lora.dict(exclude={"weight"}), context=context)
yield (lora_info.context.model, lora.weight) yield (lora_info.context.model, lora.weight)
del lora_info del lora_info
return return
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
model_name=name, model_name=name,
base_model=clip_field.text_encoder.base_model, base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion, model_type=ModelType.TextualInversion,
context=context,
).context.model ).context.model
) )
except ModelNotFoundException: except ModelNotFoundException:

View File

@ -2,16 +2,19 @@ from typing import Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (BaseInvocation, from invokeai.app.invocations.baseinvocation import (
BaseInvocationOutput, InvocationConfig, BaseInvocation,
InvocationContext) BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField, from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
VAEModelField)
class LoRAMetadataField(BaseModel): class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI.""" """LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model") lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model") weight: float = Field(description="The weight of the LoRA model")
@ -19,7 +22,9 @@ class LoRAMetadataField(BaseModel):
class CoreMetadata(BaseModel): class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI.""" """Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(description="The generation mode that output this image",) generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter") positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter") width: int = Field(description="The width parameter")
@ -29,10 +34,20 @@ class CoreMetadata(BaseModel):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter") cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference") steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference") scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",) clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference") controlnets: list[ControlField] = Field(
description="The ControlNets used for inference"
)
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# Latents-to-Latents
strength: Union[float, None] = Field( strength: Union[float, None] = Field(
default=None, default=None,
description="The strength used for latents-to-latents", description="The strength used for latents-to-latents",
@ -40,9 +55,34 @@ class CoreMetadata(BaseModel):
init_image: Union[str, None] = Field( init_image: Union[str, None] = Field(
default=None, description="The name of the initial image" default=None, description="The name of the initial image"
) )
vae: Union[VAEModelField, None] = Field(
# SDXL
positive_style_prompt: Union[str, None] = Field(
default=None, description="The positive style prompt parameter"
)
negative_style_prompt: Union[str, None] = Field(
default=None, description="The negative style prompt parameter"
)
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(
default=None, description="The SDXL Refiner model used"
)
refiner_cfg_scale: Union[float, None] = Field(
default=None, default=None,
description="The VAE used for decoding, if the main model's default was not used", description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(
default=None, description="The number of steps used for the refiner"
)
refiner_scheduler: Union[str, None] = Field(
default=None, description="The scheduler used for the refiner"
)
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(
default=None, description="The start value used for refiner denoising"
) )
@ -71,7 +111,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
type: Literal["metadata_accumulator"] = "metadata_accumulator" type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(description="The generation mode that output this image",) generation_mode: str = Field(
description="The generation mode that output this image",
)
positive_prompt: str = Field(description="The positive prompt parameter") positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter") negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter") width: int = Field(description="The width parameter")
@ -81,9 +123,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter") cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference") steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference") scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(description="The number of skipped CLIP layers",) clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
model: MainModelField = Field(description="The main model used for inference") model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference") controlnets: list[ControlField] = Field(
description="The ControlNets used for inference"
)
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field( strength: Union[float, None] = Field(
default=None, default=None,
@ -97,36 +143,44 @@ class MetadataAccumulatorInvocation(BaseInvocation):
description="The VAE used for decoding, if the main model's default was not used", description="The VAE used for decoding, if the main model's default was not used",
) )
# SDXL
positive_style_prompt: Union[str, None] = Field(
default=None, description="The positive style prompt parameter"
)
negative_style_prompt: Union[str, None] = Field(
default=None, description="The negative style prompt parameter"
)
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(
default=None, description="The SDXL Refiner model used"
)
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(
default=None, description="The number of steps used for the refiner"
)
refiner_scheduler: Union[str, None] = Field(
default=None, description="The scheduler used for the refiner"
)
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(
default=None, description="The start value used for refiner denoising"
)
class Config(InvocationConfig): class Config(InvocationConfig):
schema_extra = { schema_extra = {
"ui": { "ui": {
"title": "Metadata Accumulator", "title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"] "tags": ["image", "metadata", "generation"],
}, },
} }
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object""" """Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput( return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"ui": { "ui": {
"title": "SDXL Refiner Model Loader", "title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"], "tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "model"}, "type_hints": {"model": "refiner_model"},
}, },
} }
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context
) )
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents") latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="") denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="") denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
num_inference_steps = num_inference_steps - t_start num_inference_steps = num_inference_steps - t_start
# apply noise(if provided) # apply noise(if provided)
if self.noise is not None: if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1]) latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise del noise
unet_info = context.services.model_manager.get_model( unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict() **self.unet.unet.dict(), context=context,
) )
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None

View File

@ -65,18 +65,19 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete'; import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete'; import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError'; import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted'; import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad'; import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed'; import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed'; import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -201,3 +202,6 @@ addFirstListImagesListener();
// Ad-hoc upscale workflwo // Ad-hoc upscale workflwo
addUpscaleRequestedListener(); addUpscaleRequestedListener();
// Tab Change
addTabChangedListener();

View File

@ -9,13 +9,19 @@ import {
zMainModel, zMainModel,
zVaeModel, zVaeModel,
} from 'features/parameters/types/parameterSchemas'; } from 'features/parameters/types/parameterSchemas';
import {
refinerModelChanged,
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es'; import { forEach, some } from 'lodash-es';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..'; import { startAppListening } from '..';
export const addModelsLoadedListener = () => { export const addModelsLoadedListener = () => {
startAppListening({ startAppListening({
matcher: modelsApi.endpoints.getMainModels.matchFulfilled, predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one // models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models'); const log = logger('models');
@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
dispatch(modelChanged(result.data)); dispatch(modelChanged(result.data));
}, },
}); });
startAppListening({
predicate: (state, action) =>
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
action.meta.arg.originalArgs.includes('sdxl-refiner'),
effect: async (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info(
{ models: action.payload.entities },
`SDXL Refiner models loaded (${action.payload.ids.length})`
);
const currentModel = getState().sdxl.refinerModel;
const isCurrentModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === currentModel?.model_name &&
m?.base_model === currentModel?.base_model
);
if (isCurrentModelAvailable) {
return;
}
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No models loaded at all
dispatch(refinerModelChanged(null));
dispatch(setShouldUseSDXLRefiner(false));
return;
}
const result = zMainModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse SDXL Refiner Model'
);
return;
}
dispatch(refinerModelChanged(result.data));
},
});
startAppListening({ startAppListening({
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled, matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {

View File

@ -3,6 +3,11 @@ import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
import {
ALL_BASE_MODELS,
NON_REFINER_BASE_MODELS,
REFINER_BASE_MODELS,
} from 'services/api/constants';
export const addSocketConnectedEventListener = () => { export const addSocketConnectedEventListener = () => {
startAppListening({ startAppListening({
@ -24,7 +29,11 @@ export const addSocketConnectedEventListener = () => {
dispatch(appSocketConnected(action.payload)); dispatch(appSocketConnected(action.payload));
// update all server state // update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate()); dispatch(modelsApi.endpoints.getMainModels.initiate(REFINER_BASE_MODELS));
dispatch(
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
);
dispatch(modelsApi.endpoints.getMainModels.initiate(ALL_BASE_MODELS));
dispatch(modelsApi.endpoints.getControlNetModels.initiate()); dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate()); dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate()); dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());

View File

@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
return; return;
} }
log.debug(action.payload, 'Invocation started'); log.debug(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationStarted(action.payload)); dispatch(appSocketInvocationStarted(action.payload));
}, },
}); });

View File

@ -0,0 +1,56 @@
import { modelChanged } from 'features/parameters/store/generationSlice';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import {
MainModelConfigEntity,
modelsApi,
} from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addTabChangedListener = () => {
startAppListening({
actionCreator: setActiveTab,
effect: (action, { getState, dispatch }) => {
const activeTabName = action.payload;
if (activeTabName === 'unifiedCanvas') {
// grab the models from RTK Query cache
const { data } = modelsApi.endpoints.getMainModels.select(
NON_REFINER_BASE_MODELS
)(getState());
if (!data) {
// no models yet, so we can't do anything
dispatch(modelChanged(null));
return;
}
// need to filter out all the invalid canvas models (currently, this is just sdxl)
const validCanvasModels: MainModelConfigEntity[] = [];
forEach(data.entities, (entity) => {
if (!entity) {
return;
}
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
validCanvasModels.push(entity);
}
});
// this could still be undefined even tho TS doesn't say so
const firstValidCanvasModel = validCanvasModels[0];
if (!firstValidCanvasModel) {
// uh oh, we have no models that are valid for canvas
dispatch(modelChanged(null));
return;
}
// only store the model name and base model in redux
const { base_model, model_name } = firstValidCanvasModel;
dispatch(modelChanged({ base_model, model_name }));
}
},
});
};

View File

@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions'; import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph'; import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions'; import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session'; import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
effect: async (action, { getState, dispatch, take }) => { effect: async (action, { getState, dispatch, take }) => {
const log = logger('session'); const log = logger('session');
const state = getState(); const state = getState();
const model = state.generation.model;
let graph;
if (model && model.base_model === 'sdxl') {
graph = buildLinearSDXLImageToImageGraph(state);
} else {
graph = buildLinearImageToImageGraph(state);
}
const graph = buildLinearImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph)); dispatch(imageToImageGraphBuilt(graph));
log.debug({ graph: parseify(graph) }, 'Image to Image graph built'); log.debug({ graph: parseify(graph) }, 'Image to Image graph built');

View File

@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import { userInvoked } from 'app/store/actions'; import { userInvoked } from 'app/store/actions';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { textToImageGraphBuilt } from 'features/nodes/store/actions'; import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph'; import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
import { sessionReadyToInvoke } from 'features/system/store/actions'; import { sessionReadyToInvoke } from 'features/system/store/actions';
import { sessionCreated } from 'services/api/thunks/session'; import { sessionCreated } from 'services/api/thunks/session';
@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
effect: async (action, { getState, dispatch, take }) => { effect: async (action, { getState, dispatch, take }) => {
const log = logger('session'); const log = logger('session');
const state = getState(); const state = getState();
const model = state.generation.model;
const graph = buildLinearTextToImageGraph(state); let graph;
if (model && model.base_model === 'sdxl') {
graph = buildLinearSDXLTextToImageGraph(state);
} else {
graph = buildLinearTextToImageGraph(state);
}
dispatch(textToImageGraphBuilt(graph)); dispatch(textToImageGraphBuilt(graph));

View File

@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice'; import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice'; import systemReducer from 'features/system/store/systemSlice';
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice'; import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
@ -47,6 +48,7 @@ const allReducers = {
imageDeletion: imageDeletionReducer, imageDeletion: imageDeletionReducer,
lora: loraReducer, lora: loraReducer,
modelmanager: modelmanagerReducer, modelmanager: modelmanagerReducer,
sdxl: sdxlReducer,
[api.reducerPath]: api.reducer, [api.reducerPath]: api.reducer,
}; };
@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'canvas', 'canvas',
'gallery', 'gallery',
'generation', 'generation',
'sdxl',
'nodes', 'nodes',
'postprocessing', 'postprocessing',
'system', 'system',

View File

@ -6,6 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { modelsApi } from '../../services/api/endpoints/models'; import { modelsApi } from '../../services/api/endpoints/models';
import { ALL_BASE_MODELS } from 'services/api/constants';
const readinessSelector = createSelector( const readinessSelector = createSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
@ -24,7 +25,7 @@ const readinessSelector = createSelector(
} }
const { isSuccess: mainModelsSuccessfullyLoaded } = const { isSuccess: mainModelsSuccessfullyLoaded } =
modelsApi.endpoints.getMainModels.select()(state); modelsApi.endpoints.getMainModels.select(ALL_BASE_MODELS)(state);
if (!mainModelsSuccessfullyLoaded) { if (!mainModelsSuccessfullyLoaded) {
isReady = false; isReady = false;
reasonsWhyNotReady.push('Models are not loaded'); reasonsWhyNotReady.push('Models are not loaded');

View File

@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
import UnetInputFieldComponent from './fields/UnetInputFieldComponent'; import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
type InputFieldComponentProps = { type InputFieldComponentProps = {
nodeId: string; nodeId: string;
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'refiner_model' && template.type === 'refiner_model') {
return (
<RefinerModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'vae_model' && template.type === 'vae_model') { if (type === 'vae_model' && template.type === 'vae_model') {
return ( return (
<VaeModelInputFieldComponent <VaeModelInputFieldComponent

View File

@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
const { t } = useTranslation(); const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery(); const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {

View File

@ -0,0 +1,120 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
RefinerModelInputFieldTemplate,
RefinerModelInputFieldValue,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { FieldComponentProps } from './types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const RefinerModelInputFieldComponent = (
props: FieldComponentProps<
RefinerModelInputFieldValue,
RefinerModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${field.value?.base_model}/main/${field.value?.model_name}`
] ?? null,
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: newModel,
})
);
},
[dispatch, field.name, nodeId]
);
return isLoading ? (
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
}
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(RefinerModelInputFieldComponent);

View File

@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
ClipField: 'clip', ClipField: 'clip',
VaeField: 'vae', VaeField: 'vae',
model: 'model', model: 'model',
refiner_model: 'refiner_model',
vae_model: 'vae_model', vae_model: 'vae_model',
lora_model: 'lora_model', lora_model: 'lora_model',
controlnet_model: 'controlnet_model', controlnet_model: 'controlnet_model',
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Model', title: 'Model',
description: 'Models are models.', description: 'Models are models.',
}, },
refiner_model: {
color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Refiner Model',
description: 'Models are models.',
},
vae_model: { vae_model: {
color: 'teal', color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'), colorCssVar: getColorTokenCssVariable('teal'),

View File

@ -70,6 +70,7 @@ export type FieldType =
| 'vae' | 'vae'
| 'control' | 'control'
| 'model' | 'model'
| 'refiner_model'
| 'vae_model' | 'vae_model'
| 'lora_model' | 'lora_model'
| 'controlnet_model' | 'controlnet_model'
@ -100,6 +101,7 @@ export type InputFieldValue =
| ControlInputFieldValue | ControlInputFieldValue
| EnumInputFieldValue | EnumInputFieldValue
| MainModelInputFieldValue | MainModelInputFieldValue
| RefinerModelInputFieldValue
| VaeModelInputFieldValue | VaeModelInputFieldValue
| LoRAModelInputFieldValue | LoRAModelInputFieldValue
| ControlNetModelInputFieldValue | ControlNetModelInputFieldValue
@ -128,6 +130,7 @@ export type InputFieldTemplate =
| ControlInputFieldTemplate | ControlInputFieldTemplate
| EnumInputFieldTemplate | EnumInputFieldTemplate
| ModelInputFieldTemplate | ModelInputFieldTemplate
| RefinerModelInputFieldTemplate
| VaeModelInputFieldTemplate | VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate | LoRAModelInputFieldTemplate
| ControlNetModelInputFieldTemplate | ControlNetModelInputFieldTemplate
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
value?: MainModelParam; value?: MainModelParam;
}; };
export type RefinerModelInputFieldValue = FieldValueBase & {
type: 'refiner_model';
value?: MainModelParam;
};
export type VaeModelInputFieldValue = FieldValueBase & { export type VaeModelInputFieldValue = FieldValueBase & {
type: 'vae_model'; type: 'vae_model';
value?: VaeModelParam; value?: VaeModelParam;
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'model'; type: 'model';
}; };
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'refiner_model';
};
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
default: string; default: string;
type: 'vae_model'; type: 'vae_model';

View File

@ -22,6 +22,7 @@ import {
LoRAModelInputFieldTemplate, LoRAModelInputFieldTemplate,
ModelInputFieldTemplate, ModelInputFieldTemplate,
OutputFieldTemplate, OutputFieldTemplate,
RefinerModelInputFieldTemplate,
StringInputFieldTemplate, StringInputFieldTemplate,
TypeHints, TypeHints,
UNetInputFieldTemplate, UNetInputFieldTemplate,
@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
return template; return template;
}; };
const buildRefinerModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
const template: RefinerModelInputFieldTemplate = {
...baseField,
type: 'refiner_model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVaeModelInputFieldTemplate = ({ const buildVaeModelInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
if (['model'].includes(fieldType)) { if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField }); return buildModelInputFieldTemplate({ schemaObject, baseField });
} }
if (['refiner_model'].includes(fieldType)) {
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
}
if (['vae_model'].includes(fieldType)) { if (['vae_model'].includes(fieldType)) {
return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
} }

View File

@ -76,6 +76,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'refiner_model') {
fieldValue.value = undefined;
}
if (template.type === 'vae_model') { if (template.type === 'vae_model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }

View File

@ -0,0 +1,186 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
SDXL_REFINER_LATENTS_TO_LATENTS,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
} from './constants';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { positivePrompt, negativePrompt } = state.generation;
const {
refinerModel,
refinerAestheticScore,
positiveStylePrompt,
negativeStylePrompt,
refinerSteps,
refinerScheduler,
refinerCFGScale,
refinerStart,
} = state.sdxl;
if (!refinerModel) return;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (metadataAccumulator) {
metadataAccumulator.refiner_model = refinerModel;
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
metadataAccumulator.refiner_scheduler = refinerScheduler;
metadataAccumulator.refiner_start = refinerStart;
metadataAccumulator.refiner_steps = refinerSteps;
}
// Unplug SDXL Latents Generation To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
);
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
);
// connect the VAE back to the i2l, which we just removed in the filter
// but only if we are doing l2l
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
graph.edges.push({
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER,
model: refinerModel,
};
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_POSITIVE_CONDITIONING,
style: `${positivePrompt} ${positiveStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
style: `${negativePrompt} ${negativeStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
type: 'l2l_sdxl',
id: SDXL_REFINER_LATENTS_TO_LATENTS,
cfg_scale: refinerCFGScale,
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
scheduler: refinerScheduler,
denoising_start: refinerStart,
denoising_end: 1,
};
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: baseNodeId,
field: 'latents',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
}
);
};

View File

@ -46,6 +46,7 @@ export const buildLinearImageToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision,
} = state.generation; } = state.generation;
// TODO: add batch functionality // TODO: add batch functionality
@ -113,6 +114,7 @@ export const buildLinearImageToImageGraph = (
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
}, },
[LATENTS_TO_LATENTS]: { [LATENTS_TO_LATENTS]: {
type: 'l2l', type: 'l2l',
@ -129,6 +131,7 @@ export const buildLinearImageToImageGraph = (
// image: { // image: {
// image_name: initialImage.image_name, // image_name: initialImage.image_name,
// }, // },
fp32: vaePrecision === 'fp32' ? true : false,
}, },
}, },
edges: [ edges: [

View File

@ -0,0 +1,369 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import {
ImageResizeInvocation,
ImageToLatentsInvocation,
} from 'services/api/types';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_LATENTS_TO_LATENTS,
SDXL_MODEL_LOADER,
} from './constants';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
initialImage,
shouldFitToWidthHeight,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
sdxlImg2ImgDenoisingStrength: strength,
} = state.sdxl;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
if (!initialImage) {
log.error('No initial image found in state');
throw new Error('No initial image found in state');
}
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
[SDXL_LATENTS_TO_LATENTS]: {
type: 'l2l_sdxl',
id: SDXL_LATENTS_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_start: shouldUseSDXLRefiner
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
// must be set manually later, bc `fit` parameter may require a resize node inserted
// image: {
// image_name: initialImage.image_name,
// },
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
],
};
// handle `fit`
if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image
const resizeNode: ImageResizeInvocation = {
id: RESIZE,
type: 'img_resize',
image: {
image_name: initialImage.imageName,
},
is_intermediate: true,
width,
height,
};
graph.nodes[RESIZE] = resizeNode;
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
graph.edges.push({
source: { node_id: RESIZE, field: 'image' },
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'image',
},
});
// The `RESIZE` node also passes its width and height to `NOISE`
graph.edges.push({
source: { node_id: RESIZE, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: RESIZE, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
} else {
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
image_name: initialImage.imageName,
};
// Pass the image's dimensions to the `NOISE` node
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
destination: {
node_id: NOISE,
field: 'width',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
destination: {
node_id: NOISE,
field: 'height',
},
});
}
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_img2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
return graph;
};

View File

@ -0,0 +1,251 @@
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
SDXL_MODEL_LOADER,
SDXL_TEXT_TO_IMAGE_GRAPH,
SDXL_TEXT_TO_LATENTS,
} from './constants';
export const buildLinearSDXLTextToImageGraph = (
state: RootState
): NonNullableGraph => {
const log = logger('nodes');
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
width,
height,
clipSkip,
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
} = state.generation;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
* the `fit` param. These are added to the graph at the end.
*/
// copy-pasted graph from node editor, filled in with state values & friendly node ids
const graph: NonNullableGraph = {
id: SDXL_TEXT_TO_IMAGE_GRAPH,
nodes: {
[SDXL_MODEL_LOADER]: {
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
width,
height,
use_cpu,
},
[SDXL_TEXT_TO_LATENTS]: {
type: 't2l_sdxl',
id: SDXL_TEXT_TO_LATENTS,
cfg_scale,
scheduler,
steps,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
},
},
edges: [
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip',
},
},
{
source: {
node_id: SDXL_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: NOISE,
field: 'noise',
},
destination: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'noise',
},
},
{
source: {
node_id: SDXL_TEXT_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
},
],
};
// add metadata accumulator, which is only mostly populated - some fields are added later
graph.nodes[METADATA_ACCUMULATOR] = {
id: METADATA_ACCUMULATOR,
type: 'metadata_accumulator',
generation_mode: 'sdxl_txt2img',
cfg_scale,
height,
width,
positive_prompt: '', // set in addDynamicPromptsToGraph
negative_prompt: negativePrompt,
model,
seed: 0, // set in addDynamicPromptsToGraph
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
scheduler,
vae: undefined,
controlnets: [],
loras: [],
clip_skip: clipSkip,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'metadata',
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
return graph;
};

View File

@ -34,6 +34,7 @@ export const buildLinearTextToImageGraph = (
clipSkip, clipSkip,
shouldUseCpuNoise, shouldUseCpuNoise,
shouldUseNoiseSettings, shouldUseNoiseSettings,
vaePrecision,
} = state.generation; } = state.generation;
const use_cpu = shouldUseNoiseSettings const use_cpu = shouldUseNoiseSettings
@ -95,6 +96,7 @@ export const buildLinearTextToImageGraph = (
[LATENTS_TO_IMAGE]: { [LATENTS_TO_IMAGE]: {
type: 'l2i', type: 'l2i',
id: LATENTS_TO_IMAGE, id: LATENTS_TO_IMAGE,
fp32: vaePrecision === 'fp32' ? true : false,
}, },
}, },
edges: [ edges: [

View File

@ -23,8 +23,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
export const REALESRGAN = 'esrgan'; export const REALESRGAN = 'esrgan';
export const DIVIDE = 'divide'; export const DIVIDE = 'divide';
export const SCALE = 'scale_image'; export const SCALE = 'scale_image';
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
export const SDXL_REFINER_POSITIVE_CONDITIONING =
'sdxl_refiner_positive_conditioning';
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
// friendly graph ids // friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph'; export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph';
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph'; export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
export const INPAINT_GRAPH = 'inpaint_graph'; export const INPAINT_GRAPH = 'inpaint_graph';

View File

@ -4,6 +4,7 @@ import { memo } from 'react';
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect'; import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect'; import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
import ParamScheduler from './ParamScheduler'; import ParamScheduler from './ParamScheduler';
import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';
const ParamModelandVAEandScheduler = () => { const ParamModelandVAEandScheduler = () => {
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled; const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
<Box w="full"> <Box w="full">
<ParamMainModelSelect /> <ParamMainModelSelect />
</Box> </Box>
<Flex gap={3} w="full"> <Box w="full">
{isVaeEnabled && ( <ParamScheduler />
<Box w="full"> </Box>
<ParamVAEModelSelect /> {isVaeEnabled && (
</Box> <Flex w="full" gap={3}>
)} <ParamVAEModelSelect />
<Box w="full"> <ParamVAEPrecision />
<ParamScheduler /> </Flex>
</Box> )}
</Flex>
</Flex> </Flex>
); );
}; };

View File

@ -13,7 +13,9 @@ import { modelSelected } from 'features/parameters/store/actions';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton'; import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus'; import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
@ -29,8 +31,12 @@ const ParamMainModelSelect = () => {
const { model } = useAppSelector(selector); const { model } = useAppSelector(selector);
const { data: mainModels, isLoading } = useGetMainModelsQuery();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled; const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { data: mainModels, isLoading } = useGetMainModelsQuery(
NON_REFINER_BASE_MODELS
);
const activeTabName = useAppSelector(activeTabNameSelector);
const data = useMemo(() => { const data = useMemo(() => {
if (!mainModels) { if (!mainModels) {
@ -40,7 +46,10 @@ const ParamMainModelSelect = () => {
const data: SelectItem[] = []; const data: SelectItem[] = [];
forEach(mainModels.entities, (model, id) => { forEach(mainModels.entities, (model, id) => {
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) { if (
!model ||
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
) {
return; return;
} }
@ -52,7 +61,7 @@ const ParamMainModelSelect = () => {
}); });
return data; return data;
}, [mainModels]); }, [mainModels, activeTabName]);
// grab the full model entity from the RTK Query cache // grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state? // TODO: maybe we should just store the full model entity in state?
@ -88,7 +97,7 @@ const ParamMainModelSelect = () => {
data={[]} data={[]}
/> />
) : ( ) : (
<Flex w="100%" alignItems="center" gap={2}> <Flex w="100%" alignItems="center" gap={3}>
<IAIMantineSearchableSelect <IAIMantineSearchableSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}

View File

@ -32,11 +32,6 @@ export default function ParamSeed() {
isInvalid={seed < 0 && shouldGenerateVariations} isInvalid={seed < 0 && shouldGenerateVariations}
onChange={handleChangeSeed} onChange={handleChangeSeed}
value={seed} value={seed}
formControlProps={{
display: 'flex',
alignItems: 'center',
gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
}}
/> />
); );
} }

View File

@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';
const ParamSeedFull = () => { const ParamSeedFull = () => {
return ( return (
<Flex sx={{ gap: 4, alignItems: 'center' }}> <Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
<ParamSeed /> <ParamSeed />
<ParamSeedShuffle /> <ParamSeedShuffle />
<ParamSeedRandomize /> <ParamSeedRandomize />

View File

@ -0,0 +1,46 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
import { memo, useCallback } from 'react';
const selector = createSelector(
stateSelector,
({ generation }) => {
const { vaePrecision } = generation;
return { vaePrecision };
},
defaultSelectorOptions
);
const DATA = ['fp16', 'fp32'];
const ParamVAEModelSelect = () => {
const dispatch = useAppDispatch();
const { vaePrecision } = useAppSelector(selector);
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(vaePrecisionChanged(v as PrecisionParam));
},
[dispatch]
);
return (
<IAIMantineSelect
label="VAE Precision"
value={vaePrecision}
data={DATA}
onChange={handleChange}
/>
);
};
export default memo(ParamVAEModelSelect);

View File

@ -11,6 +11,7 @@ import {
MainModelParam, MainModelParam,
NegativePromptParam, NegativePromptParam,
PositivePromptParam, PositivePromptParam,
PrecisionParam,
SchedulerParam, SchedulerParam,
SeedParam, SeedParam,
StepsParam, StepsParam,
@ -51,6 +52,7 @@ export interface GenerationState {
verticalSymmetrySteps: number; verticalSymmetrySteps: number;
model: MainModelField | null; model: MainModelField | null;
vae: VaeModelParam | null; vae: VaeModelParam | null;
vaePrecision: PrecisionParam;
seamlessXAxis: boolean; seamlessXAxis: boolean;
seamlessYAxis: boolean; seamlessYAxis: boolean;
clipSkip: number; clipSkip: number;
@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
verticalSymmetrySteps: 0, verticalSymmetrySteps: 0,
model: null, model: null,
vae: null, vae: null,
vaePrecision: 'fp32',
seamlessXAxis: false, seamlessXAxis: false,
seamlessYAxis: false, seamlessYAxis: false,
clipSkip: 0, clipSkip: 0,
@ -241,6 +244,9 @@ export const generationSlice = createSlice({
// null is a valid VAE! // null is a valid VAE!
state.vae = action.payload; state.vae = action.payload;
}, },
vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
state.vaePrecision = action.payload;
},
setClipSkip: (state, action: PayloadAction<number>) => { setClipSkip: (state, action: PayloadAction<number>) => {
state.clipSkip = action.payload; state.clipSkip = action.payload;
}, },
@ -327,6 +333,7 @@ export const {
shouldUseCpuNoiseChanged, shouldUseCpuNoiseChanged,
setShouldShowAdvancedOptions, setShouldShowAdvancedOptions,
setAspectRatio, setAspectRatio,
vaePrecisionChanged,
} = generationSlice.actions; } = generationSlice.actions;
export default generationSlice.reducer; export default generationSlice.reducer;

View File

@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
val: unknown val: unknown
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success; ): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
/**
* Zod schema for SDXL positive style prompt parameter
*/
export const zPositiveStylePromptSDXL = z.string();
/**
* Type alias for SDXL positive style prompt parameter, inferred from its zod schema
*/
export type PositiveStylePromptSDXLParam = z.infer<
typeof zPositiveStylePromptSDXL
>;
/**
* Validates/type-guards a value as a SDXL positive style prompt parameter
*/
export const isValidSDXLPositiveStylePrompt = (
val: unknown
): val is PositiveStylePromptSDXLParam =>
zPositiveStylePromptSDXL.safeParse(val).success;
/**
* Zod schema for SDXL negative style prompt parameter
*/
export const zNegativeStylePromptSDXL = z.string();
/**
* Type alias for SDXL negative style prompt parameter, inferred from its zod schema
*/
export type NegativeStylePromptSDXLParam = z.infer<
typeof zNegativeStylePromptSDXL
>;
/**
* Validates/type-guards a value as a SDXL negative style prompt parameter
*/
export const isValidSDXLNegativeStylePrompt = (
val: unknown
): val is NegativeStylePromptSDXLParam =>
zNegativeStylePromptSDXL.safeParse(val).success;
/** /**
* Zod schema for steps parameter * Zod schema for steps parameter
*/ */
@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
export const isValidStrength = (val: unknown): val is StrengthParam => export const isValidStrength = (val: unknown): val is StrengthParam =>
zStrength.safeParse(val).success; zStrength.safeParse(val).success;
/**
* Zod schema for a precision parameter
*/
export const zPrecision = z.enum(['fp16', 'fp32']);
/**
* Type alias for precision parameter, inferred from its zod schema
*/
export type PrecisionParam = z.infer<typeof zPrecision>;
/**
* Validates/type-guards a value as a precision parameter
*/
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
zPrecision.safeParse(val).success;
// /** // /**
// * Zod schema for BaseModelType // * Zod schema for BaseModelType
// */ // */

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { sdxlImg2ImgDenoisingStrength } = sdxl;
return {
sdxlImg2ImgDenoisingStrength,
};
},
defaultSelectorOptions
);
const ParamSDXLImg2ImgDenoisingStrength = () => {
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
}, [dispatch]);
return (
<IAISlider
label={`${t('parameters.denoisingStrength')}`}
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={sdxlImg2ImgDenoisingStrength}
isInteger={false}
withInput
withSliderMarks
withReset
/>
);
};
export default memo(ParamSDXLImg2ImgDenoisingStrength);

View File

@ -0,0 +1,149 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.negativeStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLNegativeStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setNegativeStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setNegativeStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Negative Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
fontSize="sm"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLNegativeStyleConditioning;

View File

@ -0,0 +1,148 @@
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
import { createSelector } from '@reduxjs/toolkit';
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { userInvoked } from 'app/store/actions';
import IAITextarea from 'common/components/IAITextarea';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { isEqual } from 'lodash-es';
import { flushSync } from 'react-dom';
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
const promptInputSelector = createSelector(
[stateSelector, activeTabNameSelector],
({ sdxl }, activeTabName) => {
return {
prompt: sdxl.positiveStylePrompt,
activeTabName,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
/**
* Prompt input text area.
*/
const ParamSDXLPositiveStyleConditioning = () => {
const dispatch = useAppDispatch();
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
const isReady = useIsReadyToInvoke();
const promptRef = useRef<HTMLTextAreaElement>(null);
const { isOpen, onClose, onOpen } = useDisclosure();
const handleChangePrompt = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPositiveStylePromptSDXL(e.target.value));
},
[dispatch]
);
const handleSelectEmbedding = useCallback(
(v: string) => {
if (!promptRef.current) {
return;
}
// this is where we insert the TI trigger
const caret = promptRef.current.selectionStart;
if (caret === undefined) {
return;
}
let newPrompt = prompt.slice(0, caret);
if (newPrompt[newPrompt.length - 1] !== '<') {
newPrompt += '<';
}
newPrompt += `${v}>`;
// we insert the cursor after the `>`
const finalCaretPos = newPrompt.length;
newPrompt += prompt.slice(caret);
// must flush dom updates else selection gets reset
flushSync(() => {
dispatch(setPositiveStylePromptSDXL(newPrompt));
});
// set the caret position to just after the TI trigger
promptRef.current.selectionStart = finalCaretPos;
promptRef.current.selectionEnd = finalCaretPos;
onClose();
},
[dispatch, onClose, prompt]
);
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
e.preventDefault();
dispatch(clampSymmetrySteps());
dispatch(userInvoked(activeTabName));
}
if (isEmbeddingEnabled && e.key === '<') {
onOpen();
}
},
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
);
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
// const target = e.target as HTMLTextAreaElement;
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
// };
return (
<Box position="relative">
<FormControl>
<ParamEmbeddingPopover
isOpen={isOpen}
onClose={onClose}
onSelect={handleSelectEmbedding}
>
<IAITextarea
id="prompt"
name="prompt"
ref={promptRef}
value={prompt}
placeholder="Positive Style Prompt"
onChange={handleChangePrompt}
onKeyDown={handleKeyDown}
resize="vertical"
minH={16}
/>
</ParamEmbeddingPopover>
</FormControl>
{!isOpen && isEmbeddingEnabled && (
<Box
sx={{
position: 'absolute',
top: 0,
insetInlineEnd: 0,
}}
>
<AddEmbeddingButton onClick={onOpen} />
</Box>
)}
</Box>
);
};
export default ParamSDXLPositiveStyleConditioning;

View File

@ -0,0 +1,48 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
const selector = createSelector(
stateSelector,
(state) => {
const { shouldUseSDXLRefiner } = state.sdxl;
const { shouldUseSliders } = state.ui;
return {
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCollapse = () => {
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
return (
<IAICollapse label="Refiner" activeLabel={activeLabel}>
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<ParamUseSDXLRefiner />
<ParamSDXLRefinerModelSelect />
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
<ParamSDXLRefinerSteps />
<ParamSDXLRefinerCFGScale />
</Flex>
<ParamSDXLRefinerScheduler />
<ParamSDXLRefinerAestheticScore />
<ParamSDXLRefinerStart />
</Flex>
</IAICollapse>
);
};
export default ParamSDXLRefinerCollapse;

View File

@ -0,0 +1,78 @@
import { Box, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
const selector = createSelector(
[uiSelector, generationSelector],
(ui, generation) => {
const { shouldUseSliders } = ui;
const { shouldRandomizeSeed } = generation;
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
return { shouldUseSliders, activeLabel };
},
defaultSelectorOptions
);
const SDXLImageToImageTabCoreParameters = () => {
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
return (
<IAICollapse
label={'General'}
activeLabel={activeLabel}
defaultIsOpen={true}
>
<Flex
sx={{
flexDirection: 'column',
gap: 3,
}}
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>
<ParamModelandVAEandScheduler />
<Box pt={2}>
<ParamSeedFull />
</Box>
<ParamSize />
</>
)}
<ParamSDXLImg2ImgDenoisingStrength />
<ImageToImageFit />
</Flex>
</IAICollapse>
);
};
export default memo(SDXLImageToImageTabCoreParameters);

View File

@ -0,0 +1,28 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
const SDXLImageToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<SDXLImageToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLImageToImageTabParameters;

View File

@ -0,0 +1,60 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, hotkeys }) => {
const { refinerAestheticScore } = sdxl;
const { shift } = hotkeys;
return {
refinerAestheticScore,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerAestheticScore = () => {
const { refinerAestheticScore, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerAestheticScore(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerAestheticScore(6)),
[dispatch]
);
return (
<IAISlider
label="Aesthetic Score"
step={shift ? 0.1 : 0.5}
min={1}
max={10}
onChange={handleChange}
handleReset={handleReset}
value={refinerAestheticScore}
sliderNumberInputProps={{ max: 10 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerAestheticScore);

View File

@ -0,0 +1,75 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui, hotkeys }) => {
const { refinerCFGScale } = sdxl;
const { shouldUseSliders } = ui;
const { shift } = hotkeys;
return {
refinerCFGScale,
shouldUseSliders,
shift,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerCFGScale = () => {
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerCFGScale(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerCFGScale(7)),
[dispatch]
);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.cfgScale')}
step={shift ? 0.1 : 0.5}
min={1}
max={20}
onChange={handleChange}
handleReset={handleReset}
value={refinerCFGScale}
sliderNumberInputProps={{ max: 200 }}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.cfgScale')}
step={0.5}
min={1}
max={200}
onChange={handleChange}
value={refinerCFGScale}
isInteger={false}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerCFGScale);

View File

@ -0,0 +1,111 @@
import { Box, Flex } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
stateSelector,
(state) => ({ model: state.sdxl.refinerModel }),
defaultSelectorOptions
);
const ParamSDXLRefinerModelSelect = () => {
const dispatch = useAppDispatch();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
const { model } = useAppSelector(selector);
const { data: refinerModels, isLoading } =
useGetMainModelsQuery(REFINER_BASE_MODELS);
const data = useMemo(() => {
if (!refinerModels) {
return [];
}
const data: SelectItem[] = [];
forEach(refinerModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [refinerModels]);
// grab the full model entity from the RTK Query cache
// TODO: maybe we should just store the full model entity in state?
const selectedModel = useMemo(
() =>
refinerModels?.entities[
`${model?.base_model}/main/${model?.model_name}`
] ?? null,
[refinerModels?.entities, model]
);
const handleChangeModel = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newModel = modelIdToMainModelParam(v);
if (!newModel) {
return;
}
dispatch(refinerModelChanged(newModel));
},
[dispatch]
);
return isLoading ? (
<IAIMantineSearchableSelect
label="Refiner Model"
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<Flex w="100%" alignItems="center" gap={2}>
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label="Refiner Model"
value={selectedModel?.id}
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
data={data}
error={data.length === 0}
disabled={data.length === 0}
onChange={handleChangeModel}
w="100%"
/>
{isSyncModelEnabled && (
<Box mt={7}>
<SyncModelsButton iconMode />
</Box>
)}
</Flex>
);
};
export default memo(ParamSDXLRefinerModelSelect);

View File

@ -0,0 +1,65 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
SCHEDULER_LABEL_MAP,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,
({ ui, sdxl }) => {
const { refinerScheduler } = sdxl;
const { favoriteSchedulers: enabledSchedulers } = ui;
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
value: name,
label: label,
group: enabledSchedulers.includes(name as SchedulerParam)
? 'Favorites'
: undefined,
})).sort((a, b) => a.label.localeCompare(b.label));
return {
refinerScheduler,
data,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerScheduler = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { refinerScheduler, data } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: string | null) => {
if (!v) {
return;
}
dispatch(setRefinerScheduler(v as SchedulerParam));
},
[dispatch]
);
return (
<IAIMantineSearchableSelect
w="100%"
label={t('parameters.scheduler')}
value={refinerScheduler}
data={data}
onChange={handleChange}
disabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerScheduler);

View File

@ -0,0 +1,53 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISlider from 'common/components/IAISlider';
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl }) => {
const { refinerStart } = sdxl;
return {
refinerStart,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerStart = () => {
const { refinerStart } = useAppSelector(selector);
const dispatch = useAppDispatch();
const isRefinerAvailable = useIsRefinerAvailable();
const handleChange = useCallback(
(v: number) => dispatch(setRefinerStart(v)),
[dispatch]
);
const handleReset = useCallback(
() => dispatch(setRefinerStart(0.7)),
[dispatch]
);
return (
<IAISlider
label="Refiner Start"
step={0.01}
min={0}
max={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerStart}
withInput
withReset
withSliderMarks
isInteger={false}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerStart);

View File

@ -0,0 +1,72 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
[stateSelector],
({ sdxl, ui }) => {
const { refinerSteps } = sdxl;
const { shouldUseSliders } = ui;
return {
refinerSteps,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamSDXLRefinerSteps = () => {
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => {
dispatch(setRefinerSteps(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setRefinerSteps(20));
}, [dispatch]);
return shouldUseSliders ? (
<IAISlider
label={t('parameters.steps')}
min={1}
max={100}
step={1}
onChange={handleChange}
handleReset={handleReset}
value={refinerSteps}
withInput
withReset
withSliderMarks
sliderNumberInputProps={{ max: 500 }}
isDisabled={!isRefinerAvailable}
/>
) : (
<IAINumberInput
label={t('parameters.steps')}
min={1}
max={500}
step={1}
onChange={handleChange}
value={refinerSteps}
numberInputFieldProps={{ textAlign: 'center' }}
isDisabled={!isRefinerAvailable}
/>
);
};
export default memo(ParamSDXLRefinerSteps);

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
import { ChangeEvent } from 'react';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
export default function ParamUseSDXLRefiner() {
const shouldUseSDXLRefiner = useAppSelector(
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
);
const isRefinerAvailable = useIsRefinerAvailable();
const dispatch = useAppDispatch();
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldUseSDXLRefiner(e.target.checked));
};
return (
<IAISwitch
label="Use Refiner"
isChecked={shouldUseSDXLRefiner}
onChange={handleUseSDXLRefinerChange}
isDisabled={!isRefinerAvailable}
/>
);
}

View File

@ -0,0 +1,27 @@
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
const SDXLTextToImageTabParameters = () => {
return (
<>
<ParamPositiveConditioning />
<ParamSDXLPositiveStyleConditioning />
<ParamNegativeConditioning />
<ParamSDXLNegativeStyleConditioning />
<ProcessButtons />
<TextToImageTabCoreParameters />
<ParamSDXLRefinerCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
</>
);
};
export default SDXLTextToImageTabParameters;

View File

@ -0,0 +1,89 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import {
MainModelParam,
NegativeStylePromptSDXLParam,
PositiveStylePromptSDXLParam,
SchedulerParam,
} from 'features/parameters/types/parameterSchemas';
import { MainModelField } from 'services/api/types';
type SDXLInitialState = {
positiveStylePrompt: PositiveStylePromptSDXLParam;
negativeStylePrompt: NegativeStylePromptSDXLParam;
shouldUseSDXLRefiner: boolean;
sdxlImg2ImgDenoisingStrength: number;
refinerModel: MainModelField | null;
refinerSteps: number;
refinerCFGScale: number;
refinerScheduler: SchedulerParam;
refinerAestheticScore: number;
refinerStart: number;
};
const sdxlInitialState: SDXLInitialState = {
positiveStylePrompt: '',
negativeStylePrompt: '',
shouldUseSDXLRefiner: false,
sdxlImg2ImgDenoisingStrength: 0.7,
refinerModel: null,
refinerSteps: 20,
refinerCFGScale: 7.5,
refinerScheduler: 'euler',
refinerAestheticScore: 6,
refinerStart: 0.7,
};
const sdxlSlice = createSlice({
name: 'sdxl',
initialState: sdxlInitialState,
reducers: {
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.positiveStylePrompt = action.payload;
},
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
state.negativeStylePrompt = action.payload;
},
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
state.shouldUseSDXLRefiner = action.payload;
},
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
state.sdxlImg2ImgDenoisingStrength = action.payload;
},
refinerModelChanged: (
state,
action: PayloadAction<MainModelParam | null>
) => {
state.refinerModel = action.payload;
},
setRefinerSteps: (state, action: PayloadAction<number>) => {
state.refinerSteps = action.payload;
},
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
state.refinerCFGScale = action.payload;
},
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
state.refinerScheduler = action.payload;
},
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
state.refinerAestheticScore = action.payload;
},
setRefinerStart: (state, action: PayloadAction<number>) => {
state.refinerStart = action.payload;
},
},
});
export const {
setPositiveStylePromptSDXL,
setNegativeStylePromptSDXL,
setShouldUseSDXLRefiner,
setSDXLImg2ImgDenoisingStrength,
refinerModelChanged,
setRefinerSteps,
setRefinerCFGScale,
setRefinerScheduler,
setRefinerAestheticScore,
setRefinerStart,
} = sdxlSlice.actions;
export default sdxlSlice.reducer;

View File

@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { ResourceKey } from 'i18next'; import { ResourceKey } from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -172,13 +172,22 @@ const InvokeTabs = () => {
const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } = const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app'); useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
const handleTabChange = useCallback(
(index: number) => {
const activeTabName = tabMap[index];
if (!activeTabName) {
return;
}
dispatch(setActiveTab(activeTabName));
},
[dispatch]
);
return ( return (
<Tabs <Tabs
defaultIndex={activeTab} defaultIndex={activeTab}
index={activeTab} index={activeTab}
onChange={(index: number) => { onChange={handleTabChange}
dispatch(setActiveTab(index));
}}
sx={{ sx={{
flexGrow: 1, flexGrow: 1,
gap: 4, gap: 4,

View File

@ -1,7 +1,9 @@
import { Box, Flex } from '@chakra-ui/react'; import { Box, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay'; import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
import { memo, useCallback, useRef } from 'react'; import { memo, useCallback, useRef } from 'react';
import { import {
ImperativePanelGroupHandle, ImperativePanelGroupHandle,
@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
const ImageToImageTab = () => { const ImageToImageTab = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null); const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const model = useAppSelector((state: RootState) => state.generation.model);
const handleDoubleClickHandle = useCallback(() => { const handleDoubleClickHandle = useCallback(() => {
if (!panelGroupRef.current) { if (!panelGroupRef.current) {
@ -28,7 +31,11 @@ const ImageToImageTab = () => {
return ( return (
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}> <Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
<ParametersPinnedWrapper> <ParametersPinnedWrapper>
<ImageToImageTabParameters /> {model && model.base_model === 'sdxl' ? (
<SDXLImageToImageTabParameters />
) : (
<ImageToImageTabParameters />
)}
</ParametersPinnedWrapper> </ParametersPinnedWrapper>
<Box sx={{ w: 'full', h: 'full' }}> <Box sx={{ w: 'full', h: 'full' }}>
<PanelGroup <PanelGroup

View File

@ -16,6 +16,7 @@ import {
useImportMainModelsMutation, useImportMainModelsMutation,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice'; import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function FoundModelsList() { export default function FoundModelsList() {
const searchFolder = useAppSelector( const searchFolder = useAppSelector(
@ -24,7 +25,7 @@ export default function FoundModelsList() {
const [nameFilter, setNameFilter] = useState<string>(''); const [nameFilter, setNameFilter] = useState<string>('');
// Get paths of models that are already installed // Get paths of models that are already installed
const { data: installedModels } = useGetMainModelsQuery(); const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
// Get all model paths from a given directory // Get all model paths from a given directory
const { foundModels, alreadyInstalled, filteredModels } = const { foundModels, alreadyInstalled, filteredModels } =

View File

@ -1,5 +1,4 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { pickBy } from 'lodash-es'; import { pickBy } from 'lodash-es';
import { useMemo, useState } from 'react'; import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import { import {
useGetMainModelsQuery, useGetMainModelsQuery,
useMergeMainModelsMutation, useMergeMainModelsMutation,
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(); const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
const [mergeModels, { isLoading }] = useMergeMainModelsMutation(); const [mergeModels, { isLoading }] = useMergeMainModelsMutation();

View File

@ -8,10 +8,11 @@ import {
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
import { ALL_BASE_MODELS } from 'services/api/constants';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const [selectedModelId, setSelectedModelId] = useState<string>(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, { const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined, model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}), }),

View File

@ -11,6 +11,7 @@ import {
useGetMainModelsQuery, useGetMainModelsQuery,
} from 'services/api/endpoints/models'; } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { ALL_BASE_MODELS } from 'services/api/constants';
type ModelListProps = { type ModelListProps = {
selectedModelId: string | undefined; selectedModelId: string | undefined;
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
const [modelFormatFilter, setModelFormatFilter] = const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('images'); useState<ModelFormat>('images');
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, { const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter), filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}), }),
}); });
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, { const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
selectFromResult: ({ data }) => ({ selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter), filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}), }),

View File

@ -1,14 +1,22 @@
import { Flex } from '@chakra-ui/react'; import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
import { memo } from 'react'; import { memo } from 'react';
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper'; import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
import TextToImageTabMain from './TextToImageTabMain'; import TextToImageTabMain from './TextToImageTabMain';
import TextToImageTabParameters from './TextToImageTabParameters'; import TextToImageTabParameters from './TextToImageTabParameters';
const TextToImageTab = () => { const TextToImageTab = () => {
const model = useAppSelector((state: RootState) => state.generation.model);
return ( return (
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}> <Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
<ParametersPinnedWrapper> <ParametersPinnedWrapper>
<TextToImageTabParameters /> {model && model.base_model === 'sdxl' ? (
<TextToImageSDXLTabParameters />
) : (
<TextToImageTabParameters />
)}
</ParametersPinnedWrapper> </ParametersPinnedWrapper>
<TextToImageTabMain /> <TextToImageTabMain />
</Flex> </Flex>

View File

@ -26,7 +26,7 @@ export const uiSlice = createSlice({
name: 'ui', name: 'ui',
initialState: initialUIState, initialState: initialUIState,
reducers: { reducers: {
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => { setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
setActiveTabReducer(state, action.payload); setActiveTabReducer(state, action.payload);
}, },
setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => { setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => {

View File

@ -0,0 +1,16 @@
import { BaseModelType } from './types';
export const ALL_BASE_MODELS: BaseModelType[] = [
'sd-1',
'sd-2',
'sdxl',
'sdxl-refiner',
];
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
'sd-1',
'sd-2',
'sdxl',
];
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];

View File

@ -144,8 +144,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({ getMainModels: build.query<
query: () => ({ url: 'models/', params: { model_type: 'main' } }), EntityState<MainModelConfigEntity>,
BaseModelType[]
>({
query: (base_models) => {
const params = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
@ -187,7 +198,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
importMainModels: build.mutation< importMainModels: build.mutation<
ImportMainModelResponse, ImportMainModelResponse,
@ -200,7 +214,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({ addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
query: ({ body }) => { query: ({ body }) => {
@ -210,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
deleteMainModels: build.mutation< deleteMainModels: build.mutation<
DeleteMainModelResponse, DeleteMainModelResponse,
@ -222,7 +242,10 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE', method: 'DELETE',
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
convertMainModels: build.mutation< convertMainModels: build.mutation<
ConvertMainModelResponse, ConvertMainModelResponse,
@ -235,7 +258,10 @@ export const modelsApi = api.injectEndpoints({
params: params, params: params,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({ mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => { query: ({ base_model, body }) => {
@ -245,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
body: body, body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
syncModels: build.mutation<SyncModelsResponse, void>({ syncModels: build.mutation<SyncModelsResponse, void>({
query: () => { query: () => {
@ -254,7 +283,10 @@ export const modelsApi = api.injectEndpoints({
method: 'POST', method: 'POST',
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [
{ type: 'MainModel', id: LIST_TAG },
{ type: 'SDXLRefinerModel', id: LIST_TAG },
],
}), }),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),

View File

@ -0,0 +1,12 @@
import { REFINER_BASE_MODELS } from 'services/api/constants';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
export const useIsRefinerAvailable = () => {
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
selectFromResult: ({ data }) => ({
isRefinerAvailable: data ? data.ids.length > 0 : false,
}),
});
return isRefinerAvailable;
};

View File

@ -1014,6 +1014,11 @@ export type components = {
* @description The LoRAs used for inference * @description The LoRAs used for inference
*/ */
loras: (components["schemas"]["LoRAMetadataField"])[]; loras: (components["schemas"]["LoRAMetadataField"])[];
/**
* Vae
* @description The VAE used for decoding, if the main model's default was not used
*/
vae?: components["schemas"]["VAEModelField"];
/** /**
* Strength * Strength
* @description The strength used for latents-to-latents * @description The strength used for latents-to-latents
@ -1025,10 +1030,45 @@ export type components = {
*/ */
init_image?: string; init_image?: string;
/** /**
* Vae * Positive Style Prompt
* @description The VAE used for decoding, if the main model's default was not used * @description The positive style prompt parameter
*/ */
vae?: components["schemas"]["VAEModelField"]; positive_style_prompt?: string;
/**
* Negative Style Prompt
* @description The negative style prompt parameter
*/
negative_style_prompt?: string;
/**
* Refiner Model
* @description The SDXL Refiner model used
*/
refiner_model?: components["schemas"]["MainModelField"];
/**
* Refiner Cfg Scale
* @description The classifier-free guidance scale parameter used for the refiner
*/
refiner_cfg_scale?: number;
/**
* Refiner Steps
* @description The number of steps used for the refiner
*/
refiner_steps?: number;
/**
* Refiner Scheduler
* @description The scheduler used for the refiner
*/
refiner_scheduler?: string;
/**
* Refiner Aesthetic Store
* @description The aesthetic score used for the refiner
*/
refiner_aesthetic_store?: number;
/**
* Refiner Start
* @description The start value used for refiner denoising
*/
refiner_start?: number;
}; };
/** /**
* CvInpaintInvocation * CvInpaintInvocation
@ -3268,6 +3308,46 @@ export type components = {
* @description The VAE used for decoding, if the main model's default was not used * @description The VAE used for decoding, if the main model's default was not used
*/ */
vae?: components["schemas"]["VAEModelField"]; vae?: components["schemas"]["VAEModelField"];
/**
* Positive Style Prompt
* @description The positive style prompt parameter
*/
positive_style_prompt?: string;
/**
* Negative Style Prompt
* @description The negative style prompt parameter
*/
negative_style_prompt?: string;
/**
* Refiner Model
* @description The SDXL Refiner model used
*/
refiner_model?: components["schemas"]["MainModelField"];
/**
* Refiner Cfg Scale
* @description The classifier-free guidance scale parameter used for the refiner
*/
refiner_cfg_scale?: number;
/**
* Refiner Steps
* @description The number of steps used for the refiner
*/
refiner_steps?: number;
/**
* Refiner Scheduler
* @description The scheduler used for the refiner
*/
refiner_scheduler?: string;
/**
* Refiner Aesthetic Store
* @description The aesthetic score used for the refiner
*/
refiner_aesthetic_store?: number;
/**
* Refiner Start
* @description The start value used for refiner denoising
*/
refiner_start?: number;
}; };
/** /**
* MetadataAccumulatorOutput * MetadataAccumulatorOutput
@ -5355,6 +5435,12 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
@ -5367,12 +5453,6 @@ export type components = {
* @enum {string} * @enum {string}
*/ */
StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;