mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add SDXL To Linear UI (#3973)
## What type of PR is this? (check all applicable) - [x] Feature ## Have you discussed this change with the InvokeAI team? - [x] Yes ## Description This PR adds support for SDXL Models in the Linear UI ### DONE - SDXL Base Text To Image Support - SDXL Base Image To Image Support - SDXL Refiner Support - SDXL Relevant UI ## [optional] Are there any post deployment tasks we need to perform? Double check to ensure nothing major changed with 1.0 -- In any case those changes would be backend related mostly. If Refiner is scrapped for 1.0 models, then we simply disable the Refiner Graph.
This commit is contained in:
commit
531bc40d3f
@ -95,7 +95,7 @@ class CompelInvocation(BaseInvocation):
|
||||
def _lora_loader():
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -171,16 +171,16 @@ class CompelInvocation(BaseInvocation):
|
||||
class SDXLPromptInvocationBase:
|
||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -196,6 +196,7 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
@ -240,16 +241,16 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
**clip_field.tokenizer.dict(),
|
||||
**clip_field.tokenizer.dict(), context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
**clip_field.text_encoder.dict(),
|
||||
**clip_field.text_encoder.dict(), context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
**lora.dict(exclude={"weight"}))
|
||||
**lora.dict(exclude={"weight"}), context=context)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
@ -265,6 +266,7 @@ class SDXLPromptInvocationBase:
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
|
@ -2,16 +2,19 @@ from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
|
||||
BaseInvocationOutput, InvocationConfig,
|
||||
InvocationContext)
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
InvocationContext,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
|
||||
VAEModelField)
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
|
||||
|
||||
class LoRAMetadataField(BaseModel):
|
||||
"""LoRA metadata for an image generated in InvokeAI."""
|
||||
|
||||
lora: LoRAModelField = Field(description="The LoRA model")
|
||||
weight: float = Field(description="The weight of the LoRA model")
|
||||
|
||||
@ -19,7 +22,9 @@ class LoRAMetadataField(BaseModel):
|
||||
class CoreMetadata(BaseModel):
|
||||
"""Core generation metadata for an image generated in InvokeAI."""
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
@ -29,10 +34,20 @@ class CoreMetadata(BaseModel):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
clip_skip: int = Field(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
controlnets: list[ControlField] = Field(
|
||||
description="The ControlNets used for inference"
|
||||
)
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# Latents-to-Latents
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The strength used for latents-to-latents",
|
||||
@ -40,9 +55,34 @@ class CoreMetadata(BaseModel):
|
||||
init_image: Union[str, None] = Field(
|
||||
default=None, description="The name of the initial image"
|
||||
)
|
||||
vae: Union[VAEModelField, None] = Field(
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The positive style prompt parameter"
|
||||
)
|
||||
negative_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The negative style prompt parameter"
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(
|
||||
default=None, description="The SDXL Refiner model used"
|
||||
)
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(
|
||||
default=None, description="The number of steps used for the refiner"
|
||||
)
|
||||
refiner_scheduler: Union[str, None] = Field(
|
||||
default=None, description="The scheduler used for the refiner"
|
||||
)
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(
|
||||
default=None, description="The start value used for refiner denoising"
|
||||
)
|
||||
|
||||
|
||||
@ -71,7 +111,9 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
|
||||
type: Literal["metadata_accumulator"] = "metadata_accumulator"
|
||||
|
||||
generation_mode: str = Field(description="The generation mode that output this image",)
|
||||
generation_mode: str = Field(
|
||||
description="The generation mode that output this image",
|
||||
)
|
||||
positive_prompt: str = Field(description="The positive prompt parameter")
|
||||
negative_prompt: str = Field(description="The negative prompt parameter")
|
||||
width: int = Field(description="The width parameter")
|
||||
@ -81,9 +123,13 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
|
||||
steps: int = Field(description="The number of steps used for inference")
|
||||
scheduler: str = Field(description="The scheduler used for inference")
|
||||
clip_skip: int = Field(description="The number of skipped CLIP layers",)
|
||||
clip_skip: int = Field(
|
||||
description="The number of skipped CLIP layers",
|
||||
)
|
||||
model: MainModelField = Field(description="The main model used for inference")
|
||||
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
|
||||
controlnets: list[ControlField] = Field(
|
||||
description="The ControlNets used for inference"
|
||||
)
|
||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||
strength: Union[float, None] = Field(
|
||||
default=None,
|
||||
@ -97,36 +143,44 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
||||
description="The VAE used for decoding, if the main model's default was not used",
|
||||
)
|
||||
|
||||
# SDXL
|
||||
positive_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The positive style prompt parameter"
|
||||
)
|
||||
negative_style_prompt: Union[str, None] = Field(
|
||||
default=None, description="The negative style prompt parameter"
|
||||
)
|
||||
|
||||
# SDXL Refiner
|
||||
refiner_model: Union[MainModelField, None] = Field(
|
||||
default=None, description="The SDXL Refiner model used"
|
||||
)
|
||||
refiner_cfg_scale: Union[float, None] = Field(
|
||||
default=None,
|
||||
description="The classifier-free guidance scale parameter used for the refiner",
|
||||
)
|
||||
refiner_steps: Union[int, None] = Field(
|
||||
default=None, description="The number of steps used for the refiner"
|
||||
)
|
||||
refiner_scheduler: Union[str, None] = Field(
|
||||
default=None, description="The scheduler used for the refiner"
|
||||
)
|
||||
refiner_aesthetic_store: Union[float, None] = Field(
|
||||
default=None, description="The aesthetic score used for the refiner"
|
||||
)
|
||||
refiner_start: Union[float, None] = Field(
|
||||
default=None, description="The start value used for refiner denoising"
|
||||
)
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"title": "Metadata Accumulator",
|
||||
"tags": ["image", "metadata", "generation"]
|
||||
"tags": ["image", "metadata", "generation"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
|
||||
"""Collects and outputs a CoreMetadata object"""
|
||||
|
||||
return MetadataAccumulatorOutput(
|
||||
metadata=CoreMetadata(
|
||||
generation_mode=self.generation_mode,
|
||||
positive_prompt=self.positive_prompt,
|
||||
negative_prompt=self.negative_prompt,
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
seed=self.seed,
|
||||
rand_device=self.rand_device,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
scheduler=self.scheduler,
|
||||
model=self.model,
|
||||
strength=self.strength,
|
||||
init_image=self.init_image,
|
||||
vae=self.vae,
|
||||
controlnets=self.controlnets,
|
||||
loras=self.loras,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
)
|
||||
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
|
||||
|
@ -138,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"ui": {
|
||||
"title": "SDXL Refiner Model Loader",
|
||||
"tags": ["model", "loader", "sdxl_refiner"],
|
||||
"type_hints": {"model": "model"},
|
||||
"type_hints": {"model": "refiner_model"},
|
||||
},
|
||||
}
|
||||
|
||||
@ -295,7 +295,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
@ -463,8 +463,8 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||
latents: Optional[LatentsField] = Field(description="Initial latents")
|
||||
|
||||
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
|
||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||
|
||||
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
@ -549,13 +549,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||
num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# apply noise(if provided)
|
||||
if self.noise is not None:
|
||||
if self.noise is not None and timesteps.shape[0] > 0:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||
del noise
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
**self.unet.unet.dict()
|
||||
**self.unet.unet.dict(), context=context,
|
||||
)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
|
@ -65,18 +65,19 @@ import { addGeneratorProgressEventListener as addGeneratorProgressListener } fro
|
||||
import { addGraphExecutionStateCompleteEventListener as addGraphExecutionStateCompleteListener } from './listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener as addInvocationCompleteListener } from './listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener as addInvocationErrorListener } from './listeners/socketio/socketInvocationError';
|
||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||
import { addInvocationStartedEventListener as addInvocationStartedListener } from './listeners/socketio/socketInvocationStarted';
|
||||
import { addModelLoadEventListener } from './listeners/socketio/socketModelLoad';
|
||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||
import { addSocketSubscribedEventListener as addSocketSubscribedListener } from './listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } from './listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved';
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -201,3 +202,6 @@ addFirstListImagesListener();
|
||||
|
||||
// Ad-hoc upscale workflwo
|
||||
addUpscaleRequestedListener();
|
||||
|
||||
// Tab Change
|
||||
addTabChangedListener();
|
||||
|
@ -9,13 +9,19 @@ import {
|
||||
zMainModel,
|
||||
zVaeModel,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setShouldUseSDXLRefiner,
|
||||
} from 'features/sdxl/store/sdxlSlice';
|
||||
import { forEach, some } from 'lodash-es';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addModelsLoadedListener = () => {
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getMainModels.matchFulfilled,
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
!action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
@ -59,6 +65,54 @@ export const addModelsLoadedListener = () => {
|
||||
dispatch(modelChanged(result.data));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
predicate: (state, action) =>
|
||||
modelsApi.endpoints.getMainModels.matchFulfilled(action) &&
|
||||
action.meta.arg.originalArgs.includes('sdxl-refiner'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
// models loaded, we need to ensure the selected model is available and if not, select the first one
|
||||
const log = logger('models');
|
||||
log.info(
|
||||
{ models: action.payload.entities },
|
||||
`SDXL Refiner models loaded (${action.payload.ids.length})`
|
||||
);
|
||||
|
||||
const currentModel = getState().sdxl.refinerModel;
|
||||
|
||||
const isCurrentModelAvailable = some(
|
||||
action.payload.entities,
|
||||
(m) =>
|
||||
m?.model_name === currentModel?.model_name &&
|
||||
m?.base_model === currentModel?.base_model
|
||||
);
|
||||
|
||||
if (isCurrentModelAvailable) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstModelId = action.payload.ids[0];
|
||||
const firstModel = action.payload.entities[firstModelId];
|
||||
|
||||
if (!firstModel) {
|
||||
// No models loaded at all
|
||||
dispatch(refinerModelChanged(null));
|
||||
dispatch(setShouldUseSDXLRefiner(false));
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zMainModel.safeParse(firstModel);
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{ error: result.error.format() },
|
||||
'Failed to parse SDXL Refiner Model'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(refinerModelChanged(result.data));
|
||||
},
|
||||
});
|
||||
startAppListening({
|
||||
matcher: modelsApi.endpoints.getVaeModels.matchFulfilled,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
|
@ -3,6 +3,11 @@ import { modelsApi } from 'services/api/endpoints/models';
|
||||
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
|
||||
import { appSocketConnected, socketConnected } from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import {
|
||||
ALL_BASE_MODELS,
|
||||
NON_REFINER_BASE_MODELS,
|
||||
REFINER_BASE_MODELS,
|
||||
} from 'services/api/constants';
|
||||
|
||||
export const addSocketConnectedEventListener = () => {
|
||||
startAppListening({
|
||||
@ -24,7 +29,11 @@ export const addSocketConnectedEventListener = () => {
|
||||
dispatch(appSocketConnected(action.payload));
|
||||
|
||||
// update all server state
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate(REFINER_BASE_MODELS));
|
||||
dispatch(
|
||||
modelsApi.endpoints.getMainModels.initiate(NON_REFINER_BASE_MODELS)
|
||||
);
|
||||
dispatch(modelsApi.endpoints.getMainModels.initiate(ALL_BASE_MODELS));
|
||||
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
|
||||
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
|
||||
|
@ -21,7 +21,10 @@ export const addInvocationStartedEventListener = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
log.debug(action.payload, 'Invocation started');
|
||||
log.debug(
|
||||
action.payload,
|
||||
`Invocation started (${action.payload.data.node.type})`
|
||||
);
|
||||
dispatch(appSocketInvocationStarted(action.payload));
|
||||
},
|
||||
});
|
||||
|
@ -0,0 +1,56 @@
|
||||
import { modelChanged } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
modelsApi,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addTabChangedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: setActiveTab,
|
||||
effect: (action, { getState, dispatch }) => {
|
||||
const activeTabName = action.payload;
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
// grab the models from RTK Query cache
|
||||
const { data } = modelsApi.endpoints.getMainModels.select(
|
||||
NON_REFINER_BASE_MODELS
|
||||
)(getState());
|
||||
|
||||
if (!data) {
|
||||
// no models yet, so we can't do anything
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
// need to filter out all the invalid canvas models (currently, this is just sdxl)
|
||||
const validCanvasModels: MainModelConfigEntity[] = [];
|
||||
|
||||
forEach(data.entities, (entity) => {
|
||||
if (!entity) {
|
||||
return;
|
||||
}
|
||||
if (['sd-1', 'sd-2'].includes(entity.base_model)) {
|
||||
validCanvasModels.push(entity);
|
||||
}
|
||||
});
|
||||
|
||||
// this could still be undefined even tho TS doesn't say so
|
||||
const firstValidCanvasModel = validCanvasModels[0];
|
||||
|
||||
if (!firstValidCanvasModel) {
|
||||
// uh oh, we have no models that are valid for canvas
|
||||
dispatch(modelChanged(null));
|
||||
return;
|
||||
}
|
||||
|
||||
// only store the model name and base model in redux
|
||||
const { base_model, model_name } = firstValidCanvasModel;
|
||||
|
||||
dispatch(modelChanged({ base_model, model_name }));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -3,6 +3,7 @@ import { userInvoked } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
@ -14,8 +15,16 @@ export const addUserInvokedImageToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const log = logger('session');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
|
||||
let graph;
|
||||
|
||||
if (model && model.base_model === 'sdxl') {
|
||||
graph = buildLinearSDXLImageToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearImageToImageGraph(state);
|
||||
}
|
||||
|
||||
const graph = buildLinearImageToImageGraph(state);
|
||||
dispatch(imageToImageGraphBuilt(graph));
|
||||
log.debug({ graph: parseify(graph) }, 'Image to Image graph built');
|
||||
|
||||
|
@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
import { sessionReadyToInvoke } from 'features/system/store/actions';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
@ -14,8 +15,15 @@ export const addUserInvokedTextToImageListener = () => {
|
||||
effect: async (action, { getState, dispatch, take }) => {
|
||||
const log = logger('session');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
|
||||
const graph = buildLinearTextToImageGraph(state);
|
||||
let graph;
|
||||
|
||||
if (model && model.base_model === 'sdxl') {
|
||||
graph = buildLinearSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearTextToImageGraph(state);
|
||||
}
|
||||
|
||||
dispatch(textToImageGraphBuilt(graph));
|
||||
|
||||
|
@ -15,6 +15,7 @@ import loraReducer from 'features/lora/store/loraSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
|
||||
import configReducer from 'features/system/store/configSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import modelmanagerReducer from 'features/ui/components/tabs/ModelManager/store/modelManagerSlice';
|
||||
@ -47,6 +48,7 @@ const allReducers = {
|
||||
imageDeletion: imageDeletionReducer,
|
||||
lora: loraReducer,
|
||||
modelmanager: modelmanagerReducer,
|
||||
sdxl: sdxlReducer,
|
||||
[api.reducerPath]: api.reducer,
|
||||
};
|
||||
|
||||
@ -58,6 +60,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
|
||||
'canvas',
|
||||
'gallery',
|
||||
'generation',
|
||||
'sdxl',
|
||||
'nodes',
|
||||
'postprocessing',
|
||||
'system',
|
||||
|
@ -6,6 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { modelsApi } from '../../services/api/endpoints/models';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
const readinessSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
@ -24,7 +25,7 @@ const readinessSelector = createSelector(
|
||||
}
|
||||
|
||||
const { isSuccess: mainModelsSuccessfullyLoaded } =
|
||||
modelsApi.endpoints.getMainModels.select()(state);
|
||||
modelsApi.endpoints.getMainModels.select(ALL_BASE_MODELS)(state);
|
||||
if (!mainModelsSuccessfullyLoaded) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('Models are not loaded');
|
||||
|
@ -20,6 +20,7 @@ import StringInputFieldComponent from './fields/StringInputFieldComponent';
|
||||
import UnetInputFieldComponent from './fields/UnetInputFieldComponent';
|
||||
import VaeInputFieldComponent from './fields/VaeInputFieldComponent';
|
||||
import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent';
|
||||
import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
@ -155,6 +156,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'refiner_model' && template.type === 'refiner_model') {
|
||||
return (
|
||||
<RefinerModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'vae_model' && template.type === 'vae_model') {
|
||||
return (
|
||||
<VaeModelInputFieldComponent
|
||||
|
@ -14,6 +14,7 @@ import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { useFeatureStatus } from '../../../system/hooks/useFeatureStatus';
|
||||
@ -27,7 +28,9 @@ const ModelInputFieldComponent = (
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
|
@ -0,0 +1,120 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
RefinerModelInputFieldTemplate,
|
||||
RefinerModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
|
||||
const RefinerModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
RefinerModelInputFieldValue,
|
||||
RefinerModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
refinerModels?.entities[
|
||||
`${field.value?.base_model}/main/${field.value?.model_name}`
|
||||
] ?? null,
|
||||
[field.value?.base_model, field.value?.model_name, refinerModels?.entities]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label={t('modelManager.model')}
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(RefinerModelInputFieldComponent);
|
@ -17,6 +17,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
ClipField: 'clip',
|
||||
VaeField: 'vae',
|
||||
model: 'model',
|
||||
refiner_model: 'refiner_model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
@ -120,6 +121,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
refiner_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Refiner Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
vae_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
|
@ -70,6 +70,7 @@ export type FieldType =
|
||||
| 'vae'
|
||||
| 'control'
|
||||
| 'model'
|
||||
| 'refiner_model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
||||
| ControlInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| MainModelInputFieldValue
|
||||
| RefinerModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
@ -128,6 +130,7 @@ export type InputFieldTemplate =
|
||||
| ControlInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| RefinerModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
@ -243,6 +246,11 @@ export type MainModelInputFieldValue = FieldValueBase & {
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldValue = FieldValueBase & {
|
||||
type: 'refiner_model';
|
||||
value?: MainModelParam;
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldValue = FieldValueBase & {
|
||||
type: 'vae_model';
|
||||
value?: VaeModelParam;
|
||||
@ -367,6 +375,11 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type RefinerModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'refiner_model';
|
||||
};
|
||||
|
||||
export type VaeModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'vae_model';
|
||||
|
@ -22,6 +22,7 @@ import {
|
||||
LoRAModelInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
OutputFieldTemplate,
|
||||
RefinerModelInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
TypeHints,
|
||||
UNetInputFieldTemplate,
|
||||
@ -178,6 +179,21 @@ const buildModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildRefinerModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): RefinerModelInputFieldTemplate => {
|
||||
const template: RefinerModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'refiner_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVaeModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -492,6 +508,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['refiner_model'].includes(fieldType)) {
|
||||
return buildRefinerModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['vae_model'].includes(fieldType)) {
|
||||
return buildVaeModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -76,6 +76,10 @@ export const buildInputFieldValue = (
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'refiner_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'vae_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
@ -0,0 +1,186 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { MetadataAccumulatorInvocation } from 'services/api/types';
|
||||
import { NonNullableGraph } from '../../types/types';
|
||||
import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
SDXL_LATENTS_TO_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
SDXL_REFINER_MODEL_LOADER,
|
||||
SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
} from './constants';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): void => {
|
||||
const { positivePrompt, negativePrompt } = state.generation;
|
||||
const {
|
||||
refinerModel,
|
||||
refinerAestheticScore,
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
refinerSteps,
|
||||
refinerScheduler,
|
||||
refinerCFGScale,
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
if (!refinerModel) return;
|
||||
|
||||
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||
| MetadataAccumulatorInvocation
|
||||
| undefined;
|
||||
|
||||
if (metadataAccumulator) {
|
||||
metadataAccumulator.refiner_model = refinerModel;
|
||||
metadataAccumulator.refiner_aesthetic_store = refinerAestheticScore;
|
||||
metadataAccumulator.refiner_cfg_scale = refinerCFGScale;
|
||||
metadataAccumulator.refiner_scheduler = refinerScheduler;
|
||||
metadataAccumulator.refiner_start = refinerStart;
|
||||
metadataAccumulator.refiner_steps = refinerSteps;
|
||||
}
|
||||
|
||||
// Unplug SDXL Latents Generation To Latents To Image
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
|
||||
);
|
||||
|
||||
graph.edges = graph.edges.filter(
|
||||
(e) =>
|
||||
!(
|
||||
e.source.node_id === SDXL_MODEL_LOADER &&
|
||||
['vae'].includes(e.source.field)
|
||||
)
|
||||
);
|
||||
|
||||
// connect the VAE back to the i2l, which we just removed in the filter
|
||||
// but only if we are doing l2l
|
||||
if (baseNodeId === SDXL_LATENTS_TO_LATENTS) {
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
|
||||
type: 'sdxl_refiner_model_loader',
|
||||
id: SDXL_REFINER_MODEL_LOADER,
|
||||
model: refinerModel,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
style: `${positivePrompt} ${positiveStylePrompt}`,
|
||||
aesthetic_score: refinerAestheticScore,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
|
||||
type: 'sdxl_refiner_compel_prompt',
|
||||
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
style: `${negativePrompt} ${negativeStylePrompt}`,
|
||||
aesthetic_score: refinerAestheticScore,
|
||||
};
|
||||
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
|
||||
type: 'l2l_sdxl',
|
||||
id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
cfg_scale: refinerCFGScale,
|
||||
steps: refinerSteps / (1 - Math.min(refinerStart, 0.99)),
|
||||
scheduler: refinerScheduler,
|
||||
denoising_start: refinerStart,
|
||||
denoising_end: 1,
|
||||
};
|
||||
|
||||
graph.edges.push(
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: baseNodeId,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
@ -46,6 +46,7 @@ export const buildLinearImageToImageGraph = (
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
// TODO: add batch functionality
|
||||
@ -113,6 +114,7 @@ export const buildLinearImageToImageGraph = (
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l',
|
||||
@ -129,6 +131,7 @@ export const buildLinearImageToImageGraph = (
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -0,0 +1,369 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
ImageResizeInvocation,
|
||||
ImageToLatentsInvocation,
|
||||
} from 'services/api/types';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import {
|
||||
IMAGE_TO_LATENTS,
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
RESIZE,
|
||||
SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
SDXL_LATENTS_TO_LATENTS,
|
||||
SDXL_MODEL_LOADER,
|
||||
} from './constants';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearSDXLImageToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
initialImage,
|
||||
shouldFitToWidthHeight,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const {
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
shouldUseSDXLRefiner,
|
||||
refinerStart,
|
||||
sdxlImg2ImgDenoisingStrength: strength,
|
||||
} = state.sdxl;
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
if (!initialImage) {
|
||||
log.error('No initial image found in state');
|
||||
throw new Error('No initial image found in state');
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_IMAGE_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[SDXL_MODEL_LOADER]: {
|
||||
type: 'sdxl_model_loader',
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
use_cpu,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
[SDXL_LATENTS_TO_LATENTS]: {
|
||||
type: 'l2l_sdxl',
|
||||
id: SDXL_LATENTS_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_start: shouldUseSDXLRefiner
|
||||
? Math.min(refinerStart, 1 - strength)
|
||||
: 1 - strength,
|
||||
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||
},
|
||||
[IMAGE_TO_LATENTS]: {
|
||||
type: 'i2l',
|
||||
id: IMAGE_TO_LATENTS,
|
||||
// must be set manually later, bc `fit` parameter may require a resize node inserted
|
||||
// image: {
|
||||
// image_name: initialImage.image_name,
|
||||
// },
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_LATENTS_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// handle `fit`
|
||||
if (
|
||||
shouldFitToWidthHeight &&
|
||||
(initialImage.width !== width || initialImage.height !== height)
|
||||
) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
const resizeNode: ImageResizeInvocation = {
|
||||
id: RESIZE,
|
||||
type: 'img_resize',
|
||||
image: {
|
||||
image_name: initialImage.imageName,
|
||||
},
|
||||
is_intermediate: true,
|
||||
width,
|
||||
height,
|
||||
};
|
||||
|
||||
graph.nodes[RESIZE] = resizeNode;
|
||||
|
||||
// The `RESIZE` node then passes its image to `IMAGE_TO_LATENTS`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'image' },
|
||||
destination: {
|
||||
node_id: IMAGE_TO_LATENTS,
|
||||
field: 'image',
|
||||
},
|
||||
});
|
||||
|
||||
// The `RESIZE` node also passes its width and height to `NOISE`
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: RESIZE, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
} else {
|
||||
// We are not resizing, so we need to set the image on the `IMAGE_TO_LATENTS` node explicitly
|
||||
(graph.nodes[IMAGE_TO_LATENTS] as ImageToLatentsInvocation).image = {
|
||||
image_name: initialImage.imageName,
|
||||
};
|
||||
|
||||
// Pass the image's dimensions to the `NOISE` node
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'width' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'width',
|
||||
},
|
||||
});
|
||||
graph.edges.push({
|
||||
source: { node_id: IMAGE_TO_LATENTS, field: 'height' },
|
||||
destination: {
|
||||
node_id: NOISE,
|
||||
field: 'height',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'sdxl_img2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
strength: strength,
|
||||
init_image: initialImage.imageName,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (shouldUseSDXLRefiner) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_LATENTS_TO_LATENTS);
|
||||
}
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,251 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||
import { initialGenerationState } from 'features/parameters/store/generationSlice';
|
||||
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
|
||||
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
|
||||
import {
|
||||
LATENTS_TO_IMAGE,
|
||||
METADATA_ACCUMULATOR,
|
||||
NEGATIVE_CONDITIONING,
|
||||
NOISE,
|
||||
POSITIVE_CONDITIONING,
|
||||
SDXL_MODEL_LOADER,
|
||||
SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
SDXL_TEXT_TO_LATENTS,
|
||||
} from './constants';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (
|
||||
state: RootState
|
||||
): NonNullableGraph => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
negativePrompt,
|
||||
model,
|
||||
cfgScale: cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const {
|
||||
positiveStylePrompt,
|
||||
negativeStylePrompt,
|
||||
shouldUseSDXLRefiner,
|
||||
refinerStart,
|
||||
} = state.sdxl;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
? shouldUseCpuNoise
|
||||
: initialGenerationState.shouldUseCpuNoise;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
}
|
||||
|
||||
/**
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
* The only thing we need extra logic for is handling randomized seed, control net, and for img2img,
|
||||
* the `fit` param. These are added to the graph at the end.
|
||||
*/
|
||||
|
||||
// copy-pasted graph from node editor, filled in with state values & friendly node ids
|
||||
const graph: NonNullableGraph = {
|
||||
id: SDXL_TEXT_TO_IMAGE_GRAPH,
|
||||
nodes: {
|
||||
[SDXL_MODEL_LOADER]: {
|
||||
type: 'sdxl_model_loader',
|
||||
id: SDXL_MODEL_LOADER,
|
||||
model,
|
||||
},
|
||||
[POSITIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: POSITIVE_CONDITIONING,
|
||||
prompt: positivePrompt,
|
||||
style: positiveStylePrompt,
|
||||
},
|
||||
[NEGATIVE_CONDITIONING]: {
|
||||
type: 'sdxl_compel_prompt',
|
||||
id: NEGATIVE_CONDITIONING,
|
||||
prompt: negativePrompt,
|
||||
style: negativeStylePrompt,
|
||||
},
|
||||
[NOISE]: {
|
||||
type: 'noise',
|
||||
id: NOISE,
|
||||
width,
|
||||
height,
|
||||
use_cpu,
|
||||
},
|
||||
[SDXL_TEXT_TO_LATENTS]: {
|
||||
type: 't2l_sdxl',
|
||||
id: SDXL_TEXT_TO_LATENTS,
|
||||
cfg_scale,
|
||||
scheduler,
|
||||
steps,
|
||||
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
|
||||
},
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_MODEL_LOADER,
|
||||
field: 'clip2',
|
||||
},
|
||||
destination: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: POSITIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'positive_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NEGATIVE_CONDITIONING,
|
||||
field: 'conditioning',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'negative_conditioning',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: NOISE,
|
||||
field: 'noise',
|
||||
},
|
||||
destination: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'noise',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: SDXL_TEXT_TO_LATENTS,
|
||||
field: 'latents',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'latents',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// add metadata accumulator, which is only mostly populated - some fields are added later
|
||||
graph.nodes[METADATA_ACCUMULATOR] = {
|
||||
id: METADATA_ACCUMULATOR,
|
||||
type: 'metadata_accumulator',
|
||||
generation_mode: 'sdxl_txt2img',
|
||||
cfg_scale,
|
||||
height,
|
||||
width,
|
||||
positive_prompt: '', // set in addDynamicPromptsToGraph
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
seed: 0, // set in addDynamicPromptsToGraph
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
scheduler,
|
||||
vae: undefined,
|
||||
controlnets: [],
|
||||
loras: [],
|
||||
clip_skip: clipSkip,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
};
|
||||
|
||||
graph.edges.push({
|
||||
source: {
|
||||
node_id: METADATA_ACCUMULATOR,
|
||||
field: 'metadata',
|
||||
},
|
||||
destination: {
|
||||
node_id: LATENTS_TO_IMAGE,
|
||||
field: 'metadata',
|
||||
},
|
||||
});
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (shouldUseSDXLRefiner) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
|
||||
}
|
||||
|
||||
// add dynamic prompts - also sets up core iteration and seed
|
||||
addDynamicPromptsToGraph(state, graph);
|
||||
|
||||
return graph;
|
||||
};
|
@ -34,6 +34,7 @@ export const buildLinearTextToImageGraph = (
|
||||
clipSkip,
|
||||
shouldUseCpuNoise,
|
||||
shouldUseNoiseSettings,
|
||||
vaePrecision,
|
||||
} = state.generation;
|
||||
|
||||
const use_cpu = shouldUseNoiseSettings
|
||||
@ -95,6 +96,7 @@ export const buildLinearTextToImageGraph = (
|
||||
[LATENTS_TO_IMAGE]: {
|
||||
type: 'l2i',
|
||||
id: LATENTS_TO_IMAGE,
|
||||
fp32: vaePrecision === 'fp32' ? true : false,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
|
@ -23,8 +23,19 @@ export const METADATA_ACCUMULATOR = 'metadata_accumulator';
|
||||
export const REALESRGAN = 'esrgan';
|
||||
export const DIVIDE = 'divide';
|
||||
export const SCALE = 'scale_image';
|
||||
export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
|
||||
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
|
||||
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
|
||||
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
|
||||
export const SDXL_REFINER_POSITIVE_CONDITIONING =
|
||||
'sdxl_refiner_positive_conditioning';
|
||||
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
|
||||
'sdxl_refiner_negative_conditioning';
|
||||
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
|
||||
|
||||
// friendly graph ids
|
||||
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';
|
||||
export const SDXL_TEXT_TO_IMAGE_GRAPH = 'sdxl_text_to_image_graph';
|
||||
export const SDXL_IMAGE_TO_IMAGE_GRAPH = 'sxdl_image_to_image_graph';
|
||||
export const IMAGE_TO_IMAGE_GRAPH = 'image_to_image_graph';
|
||||
export const INPAINT_GRAPH = 'inpaint_graph';
|
||||
|
@ -4,6 +4,7 @@ import { memo } from 'react';
|
||||
import ParamMainModelSelect from '../MainModel/ParamMainModelSelect';
|
||||
import ParamVAEModelSelect from '../VAEModel/ParamVAEModelSelect';
|
||||
import ParamScheduler from './ParamScheduler';
|
||||
import ParamVAEPrecision from '../VAEModel/ParamVAEPrecision';
|
||||
|
||||
const ParamModelandVAEandScheduler = () => {
|
||||
const isVaeEnabled = useFeatureStatus('vae').isFeatureEnabled;
|
||||
@ -13,16 +14,15 @@ const ParamModelandVAEandScheduler = () => {
|
||||
<Box w="full">
|
||||
<ParamMainModelSelect />
|
||||
</Box>
|
||||
<Flex gap={3} w="full">
|
||||
{isVaeEnabled && (
|
||||
<Box w="full">
|
||||
<ParamVAEModelSelect />
|
||||
</Box>
|
||||
)}
|
||||
<Box w="full">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
</Flex>
|
||||
<Box w="full">
|
||||
<ParamScheduler />
|
||||
</Box>
|
||||
{isVaeEnabled && (
|
||||
<Flex w="full" gap={3}>
|
||||
<ParamVAEModelSelect />
|
||||
<ParamVAEPrecision />
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -13,7 +13,9 @@ import { modelSelected } from 'features/parameters/store/actions';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { NON_REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import { useFeatureStatus } from '../../../../system/hooks/useFeatureStatus';
|
||||
|
||||
@ -29,8 +31,12 @@ const ParamMainModelSelect = () => {
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
const { data: mainModels, isLoading } = useGetMainModelsQuery(
|
||||
NON_REFINER_BASE_MODELS
|
||||
);
|
||||
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!mainModels) {
|
||||
@ -40,7 +46,10 @@ const ParamMainModelSelect = () => {
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(mainModels.entities, (model, id) => {
|
||||
if (!model || ['sdxl', 'sdxl-refiner'].includes(model.base_model)) {
|
||||
if (
|
||||
!model ||
|
||||
(activeTabName === 'unifiedCanvas' && model.base_model === 'sdxl')
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -52,7 +61,7 @@ const ParamMainModelSelect = () => {
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [mainModels]);
|
||||
}, [mainModels, activeTabName]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
@ -88,7 +97,7 @@ const ParamMainModelSelect = () => {
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<Flex w="100%" alignItems="center" gap={3}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={t('modelManager.model')}
|
||||
|
@ -32,11 +32,6 @@ export default function ParamSeed() {
|
||||
isInvalid={seed < 0 && shouldGenerateVariations}
|
||||
onChange={handleChangeSeed}
|
||||
value={seed}
|
||||
formControlProps={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 3, // really this should work with 2 but seems to need to be 3 to match gap 2?
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import ParamSeedRandomize from './ParamSeedRandomize';
|
||||
|
||||
const ParamSeedFull = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 4, alignItems: 'center' }}>
|
||||
<Flex sx={{ gap: 3, alignItems: 'flex-end' }}>
|
||||
<ParamSeed />
|
||||
<ParamSeedShuffle />
|
||||
<ParamSeedRandomize />
|
||||
|
@ -0,0 +1,46 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { vaePrecisionChanged } from 'features/parameters/store/generationSlice';
|
||||
import { PrecisionParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ generation }) => {
|
||||
const { vaePrecision } = generation;
|
||||
return { vaePrecision };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const DATA = ['fp16', 'fp32'];
|
||||
|
||||
const ParamVAEModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { vaePrecision } = useAppSelector(selector);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(vaePrecisionChanged(v as PrecisionParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
label="VAE Precision"
|
||||
value={vaePrecision}
|
||||
data={DATA}
|
||||
onChange={handleChange}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamVAEModelSelect);
|
@ -11,6 +11,7 @@ import {
|
||||
MainModelParam,
|
||||
NegativePromptParam,
|
||||
PositivePromptParam,
|
||||
PrecisionParam,
|
||||
SchedulerParam,
|
||||
SeedParam,
|
||||
StepsParam,
|
||||
@ -51,6 +52,7 @@ export interface GenerationState {
|
||||
verticalSymmetrySteps: number;
|
||||
model: MainModelField | null;
|
||||
vae: VaeModelParam | null;
|
||||
vaePrecision: PrecisionParam;
|
||||
seamlessXAxis: boolean;
|
||||
seamlessYAxis: boolean;
|
||||
clipSkip: number;
|
||||
@ -89,6 +91,7 @@ export const initialGenerationState: GenerationState = {
|
||||
verticalSymmetrySteps: 0,
|
||||
model: null,
|
||||
vae: null,
|
||||
vaePrecision: 'fp32',
|
||||
seamlessXAxis: false,
|
||||
seamlessYAxis: false,
|
||||
clipSkip: 0,
|
||||
@ -241,6 +244,9 @@ export const generationSlice = createSlice({
|
||||
// null is a valid VAE!
|
||||
state.vae = action.payload;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<PrecisionParam>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
},
|
||||
setClipSkip: (state, action: PayloadAction<number>) => {
|
||||
state.clipSkip = action.payload;
|
||||
},
|
||||
@ -327,6 +333,7 @@ export const {
|
||||
shouldUseCpuNoiseChanged,
|
||||
setShouldShowAdvancedOptions,
|
||||
setAspectRatio,
|
||||
vaePrecisionChanged,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export default generationSlice.reducer;
|
||||
|
@ -42,6 +42,42 @@ export const isValidNegativePrompt = (
|
||||
val: unknown
|
||||
): val is NegativePromptParam => zNegativePrompt.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL positive style prompt parameter
|
||||
*/
|
||||
export const zPositiveStylePromptSDXL = z.string();
|
||||
/**
|
||||
* Type alias for SDXL positive style prompt parameter, inferred from its zod schema
|
||||
*/
|
||||
export type PositiveStylePromptSDXLParam = z.infer<
|
||||
typeof zPositiveStylePromptSDXL
|
||||
>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL positive style prompt parameter
|
||||
*/
|
||||
export const isValidSDXLPositiveStylePrompt = (
|
||||
val: unknown
|
||||
): val is PositiveStylePromptSDXLParam =>
|
||||
zPositiveStylePromptSDXL.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for SDXL negative style prompt parameter
|
||||
*/
|
||||
export const zNegativeStylePromptSDXL = z.string();
|
||||
/**
|
||||
* Type alias for SDXL negative style prompt parameter, inferred from its zod schema
|
||||
*/
|
||||
export type NegativeStylePromptSDXLParam = z.infer<
|
||||
typeof zNegativeStylePromptSDXL
|
||||
>;
|
||||
/**
|
||||
* Validates/type-guards a value as a SDXL negative style prompt parameter
|
||||
*/
|
||||
export const isValidSDXLNegativeStylePrompt = (
|
||||
val: unknown
|
||||
): val is NegativeStylePromptSDXLParam =>
|
||||
zNegativeStylePromptSDXL.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for steps parameter
|
||||
*/
|
||||
@ -260,6 +296,20 @@ export type StrengthParam = z.infer<typeof zStrength>;
|
||||
export const isValidStrength = (val: unknown): val is StrengthParam =>
|
||||
zStrength.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for a precision parameter
|
||||
*/
|
||||
export const zPrecision = z.enum(['fp16', 'fp32']);
|
||||
/**
|
||||
* Type alias for precision parameter, inferred from its zod schema
|
||||
*/
|
||||
export type PrecisionParam = z.infer<typeof zPrecision>;
|
||||
/**
|
||||
* Validates/type-guards a value as a precision parameter
|
||||
*/
|
||||
export const isValidPrecision = (val: unknown): val is PrecisionParam =>
|
||||
zPrecision.safeParse(val).success;
|
||||
|
||||
// /**
|
||||
// * Zod schema for BaseModelType
|
||||
// */
|
||||
|
@ -0,0 +1,53 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { setSDXLImg2ImgDenoisingStrength } from '../store/sdxlSlice';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl }) => {
|
||||
const { sdxlImg2ImgDenoisingStrength } = sdxl;
|
||||
|
||||
return {
|
||||
sdxlImg2ImgDenoisingStrength,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLImg2ImgDenoisingStrength = () => {
|
||||
const { sdxlImg2ImgDenoisingStrength } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setSDXLImg2ImgDenoisingStrength(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(setSDXLImg2ImgDenoisingStrength(0.7));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label={`${t('parameters.denoisingStrength')}`}
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={sdxlImg2ImgDenoisingStrength}
|
||||
isInteger={false}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLImg2ImgDenoisingStrength);
|
@ -0,0 +1,149 @@
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { setNegativeStylePromptSDXL } from '../store/sdxlSlice';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ sdxl }, activeTabName) => {
|
||||
return {
|
||||
prompt: sdxl.negativeStylePrompt,
|
||||
activeTabName,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Prompt input text area.
|
||||
*/
|
||||
const ParamSDXLNegativeStyleConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setNegativeStylePromptSDXL(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setNegativeStylePromptSDXL(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
e.preventDefault();
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (isEmbeddingEnabled && e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder="Negative Style Prompt"
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
fontSize="sm"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLNegativeStyleConditioning;
|
@ -0,0 +1,148 @@
|
||||
import { Box, FormControl, useDisclosure } from '@chakra-ui/react';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react';
|
||||
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { clampSymmetrySteps } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import IAITextarea from 'common/components/IAITextarea';
|
||||
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
|
||||
import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton';
|
||||
import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { flushSync } from 'react-dom';
|
||||
import { setPositiveStylePromptSDXL } from '../store/sdxlSlice';
|
||||
|
||||
const promptInputSelector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ sdxl }, activeTabName) => {
|
||||
return {
|
||||
prompt: sdxl.positiveStylePrompt,
|
||||
activeTabName,
|
||||
};
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
/**
|
||||
* Prompt input text area.
|
||||
*/
|
||||
const ParamSDXLPositiveStyleConditioning = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { prompt, activeTabName } = useAppSelector(promptInputSelector);
|
||||
const isReady = useIsReadyToInvoke();
|
||||
const promptRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { isOpen, onClose, onOpen } = useDisclosure();
|
||||
|
||||
const handleChangePrompt = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
dispatch(setPositiveStylePromptSDXL(e.target.value));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleSelectEmbedding = useCallback(
|
||||
(v: string) => {
|
||||
if (!promptRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
// this is where we insert the TI trigger
|
||||
const caret = promptRef.current.selectionStart;
|
||||
|
||||
if (caret === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
let newPrompt = prompt.slice(0, caret);
|
||||
|
||||
if (newPrompt[newPrompt.length - 1] !== '<') {
|
||||
newPrompt += '<';
|
||||
}
|
||||
|
||||
newPrompt += `${v}>`;
|
||||
|
||||
// we insert the cursor after the `>`
|
||||
const finalCaretPos = newPrompt.length;
|
||||
|
||||
newPrompt += prompt.slice(caret);
|
||||
|
||||
// must flush dom updates else selection gets reset
|
||||
flushSync(() => {
|
||||
dispatch(setPositiveStylePromptSDXL(newPrompt));
|
||||
});
|
||||
|
||||
// set the caret position to just after the TI trigger
|
||||
promptRef.current.selectionStart = finalCaretPos;
|
||||
promptRef.current.selectionEnd = finalCaretPos;
|
||||
onClose();
|
||||
},
|
||||
[dispatch, onClose, prompt]
|
||||
);
|
||||
|
||||
const isEmbeddingEnabled = useFeatureStatus('embedding').isFeatureEnabled;
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === 'Enter' && e.shiftKey === false && isReady) {
|
||||
e.preventDefault();
|
||||
dispatch(clampSymmetrySteps());
|
||||
dispatch(userInvoked(activeTabName));
|
||||
}
|
||||
if (isEmbeddingEnabled && e.key === '<') {
|
||||
onOpen();
|
||||
}
|
||||
},
|
||||
[isReady, dispatch, activeTabName, onOpen, isEmbeddingEnabled]
|
||||
);
|
||||
|
||||
// const handleSelect = (e: MouseEvent<HTMLTextAreaElement>) => {
|
||||
// const target = e.target as HTMLTextAreaElement;
|
||||
// setCaret({ start: target.selectionStart, end: target.selectionEnd });
|
||||
// };
|
||||
|
||||
return (
|
||||
<Box position="relative">
|
||||
<FormControl>
|
||||
<ParamEmbeddingPopover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onSelect={handleSelectEmbedding}
|
||||
>
|
||||
<IAITextarea
|
||||
id="prompt"
|
||||
name="prompt"
|
||||
ref={promptRef}
|
||||
value={prompt}
|
||||
placeholder="Positive Style Prompt"
|
||||
onChange={handleChangePrompt}
|
||||
onKeyDown={handleKeyDown}
|
||||
resize="vertical"
|
||||
minH={16}
|
||||
/>
|
||||
</ParamEmbeddingPopover>
|
||||
</FormControl>
|
||||
{!isOpen && isEmbeddingEnabled && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
insetInlineEnd: 0,
|
||||
}}
|
||||
>
|
||||
<AddEmbeddingButton onClick={onOpen} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLPositiveStyleConditioning;
|
@ -0,0 +1,48 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamSDXLRefinerAestheticScore from './SDXLRefiner/ParamSDXLRefinerAestheticScore';
|
||||
import ParamSDXLRefinerCFGScale from './SDXLRefiner/ParamSDXLRefinerCFGScale';
|
||||
import ParamSDXLRefinerModelSelect from './SDXLRefiner/ParamSDXLRefinerModelSelect';
|
||||
import ParamSDXLRefinerScheduler from './SDXLRefiner/ParamSDXLRefinerScheduler';
|
||||
import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
|
||||
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
|
||||
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => {
|
||||
const { shouldUseSDXLRefiner } = state.sdxl;
|
||||
const { shouldUseSliders } = state.ui;
|
||||
return {
|
||||
activeLabel: shouldUseSDXLRefiner ? 'Enabled' : undefined,
|
||||
shouldUseSliders,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerCollapse = () => {
|
||||
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<IAICollapse label="Refiner" activeLabel={activeLabel}>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<ParamUseSDXLRefiner />
|
||||
<ParamSDXLRefinerModelSelect />
|
||||
<Flex gap={2} flexDirection={shouldUseSliders ? 'column' : 'row'}>
|
||||
<ParamSDXLRefinerSteps />
|
||||
<ParamSDXLRefinerCFGScale />
|
||||
</Flex>
|
||||
<ParamSDXLRefinerScheduler />
|
||||
<ParamSDXLRefinerAestheticScore />
|
||||
<ParamSDXLRefinerStart />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default ParamSDXLRefinerCollapse;
|
@ -0,0 +1,78 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAICollapse from 'common/components/IAICollapse';
|
||||
import ParamCFGScale from 'features/parameters/components/Parameters/Core/ParamCFGScale';
|
||||
import ParamIterations from 'features/parameters/components/Parameters/Core/ParamIterations';
|
||||
import ParamModelandVAEandScheduler from 'features/parameters/components/Parameters/Core/ParamModelandVAEandScheduler';
|
||||
import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize';
|
||||
import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps';
|
||||
import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit';
|
||||
import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
import ParamSDXLImg2ImgDenoisingStrength from './ParamSDXLImg2ImgDenoisingStrength';
|
||||
|
||||
const selector = createSelector(
|
||||
[uiSelector, generationSelector],
|
||||
(ui, generation) => {
|
||||
const { shouldUseSliders } = ui;
|
||||
const { shouldRandomizeSeed } = generation;
|
||||
|
||||
const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined;
|
||||
|
||||
return { shouldUseSliders, activeLabel };
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const SDXLImageToImageTabCoreParameters = () => {
|
||||
const { shouldUseSliders, activeLabel } = useAppSelector(selector);
|
||||
|
||||
return (
|
||||
<IAICollapse
|
||||
label={'General'}
|
||||
activeLabel={activeLabel}
|
||||
defaultIsOpen={true}
|
||||
>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
gap: 3,
|
||||
}}
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
<ParamSize />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
<ParamModelandVAEandScheduler />
|
||||
<Box pt={2}>
|
||||
<ParamSeedFull />
|
||||
</Box>
|
||||
<ParamSize />
|
||||
</>
|
||||
)}
|
||||
<ParamSDXLImg2ImgDenoisingStrength />
|
||||
<ImageToImageFit />
|
||||
</Flex>
|
||||
</IAICollapse>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SDXLImageToImageTabCoreParameters);
|
@ -0,0 +1,28 @@
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
// import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||
import SDXLImageToImageTabCoreParameters from './SDXLImageToImageTabCoreParameters';
|
||||
|
||||
const SDXLImageToImageTabParameters = () => {
|
||||
return (
|
||||
<>
|
||||
<ParamPositiveConditioning />
|
||||
<ParamSDXLPositiveStyleConditioning />
|
||||
<ParamNegativeConditioning />
|
||||
<ParamSDXLNegativeStyleConditioning />
|
||||
<ProcessButtons />
|
||||
<SDXLImageToImageTabCoreParameters />
|
||||
<ParamSDXLRefinerCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDXLImageToImageTabParameters;
|
@ -0,0 +1,60 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerAestheticScore } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, hotkeys }) => {
|
||||
const { refinerAestheticScore } = sdxl;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
refinerAestheticScore,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerAestheticScore = () => {
|
||||
const { refinerAestheticScore, shift } = useAppSelector(selector);
|
||||
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerAestheticScore(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerAestheticScore(6)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label="Aesthetic Score"
|
||||
step={shift ? 0.1 : 0.5}
|
||||
min={1}
|
||||
max={10}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerAestheticScore}
|
||||
sliderNumberInputProps={{ max: 10 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerAestheticScore);
|
@ -0,0 +1,75 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerCFGScale } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui, hotkeys }) => {
|
||||
const { refinerCFGScale } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
const { shift } = hotkeys;
|
||||
|
||||
return {
|
||||
refinerCFGScale,
|
||||
shouldUseSliders,
|
||||
shift,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerCFGScale = () => {
|
||||
const { refinerCFGScale, shouldUseSliders, shift } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerCFGScale(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerCFGScale(7)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return shouldUseSliders ? (
|
||||
<IAISlider
|
||||
label={t('parameters.cfgScale')}
|
||||
step={shift ? 0.1 : 0.5}
|
||||
min={1}
|
||||
max={20}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerCFGScale}
|
||||
sliderNumberInputProps={{ max: 200 }}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
) : (
|
||||
<IAINumberInput
|
||||
label={t('parameters.cfgScale')}
|
||||
step={0.5}
|
||||
min={1}
|
||||
max={200}
|
||||
onChange={handleChange}
|
||||
value={refinerCFGScale}
|
||||
isInteger={false}
|
||||
numberInputFieldProps={{ textAlign: 'center' }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerCFGScale);
|
@ -0,0 +1,111 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
|
||||
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import SyncModelsButton from 'features/ui/components/tabs/ModelManager/subpanels/ModelManagerSettingsPanel/SyncModelsButton';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
(state) => ({ model: state.sdxl.refinerModel }),
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const { data: refinerModels, isLoading } =
|
||||
useGetMainModelsQuery(REFINER_BASE_MODELS);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!refinerModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(refinerModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [refinerModels]);
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
// TODO: maybe we should just store the full model entity in state?
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
refinerModels?.entities[
|
||||
`${model?.base_model}/main/${model?.model_name}`
|
||||
] ?? null,
|
||||
[refinerModels?.entities, model]
|
||||
);
|
||||
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newModel = modelIdToMainModelParam(v);
|
||||
|
||||
if (!newModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(refinerModelChanged(newModel));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return isLoading ? (
|
||||
<IAIMantineSearchableSelect
|
||||
label="Refiner Model"
|
||||
placeholder="Loading..."
|
||||
disabled={true}
|
||||
data={[]}
|
||||
/>
|
||||
) : (
|
||||
<Flex w="100%" alignItems="center" gap={2}>
|
||||
<IAIMantineSearchableSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label="Refiner Model"
|
||||
value={selectedModel?.id}
|
||||
placeholder={data.length > 0 ? 'Select a model' : 'No models available'}
|
||||
data={data}
|
||||
error={data.length === 0}
|
||||
disabled={data.length === 0}
|
||||
onChange={handleChangeModel}
|
||||
w="100%"
|
||||
/>
|
||||
{isSyncModelEnabled && (
|
||||
<Box mt={7}>
|
||||
<SyncModelsButton iconMode />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerModelSelect);
|
@ -0,0 +1,65 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
|
||||
import {
|
||||
SCHEDULER_LABEL_MAP,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
stateSelector,
|
||||
({ ui, sdxl }) => {
|
||||
const { refinerScheduler } = sdxl;
|
||||
const { favoriteSchedulers: enabledSchedulers } = ui;
|
||||
|
||||
const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({
|
||||
value: name,
|
||||
label: label,
|
||||
group: enabledSchedulers.includes(name as SchedulerParam)
|
||||
? 'Favorites'
|
||||
: undefined,
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return {
|
||||
refinerScheduler,
|
||||
data,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerScheduler = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { refinerScheduler, data } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
dispatch(setRefinerScheduler(v as SchedulerParam));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSearchableSelect
|
||||
w="100%"
|
||||
label={t('parameters.scheduler')}
|
||||
value={refinerScheduler}
|
||||
data={data}
|
||||
onChange={handleChange}
|
||||
disabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerScheduler);
|
@ -0,0 +1,53 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerStart } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl }) => {
|
||||
const { refinerStart } = sdxl;
|
||||
return {
|
||||
refinerStart,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerStart = () => {
|
||||
const { refinerStart } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
const handleChange = useCallback(
|
||||
(v: number) => dispatch(setRefinerStart(v)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(
|
||||
() => dispatch(setRefinerStart(0.7)),
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
label="Refiner Start"
|
||||
step={0.01}
|
||||
min={0}
|
||||
max={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerStart}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
isInteger={false}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerStart);
|
@ -0,0 +1,72 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setRefinerSteps } from 'features/sdxl/store/sdxlSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
({ sdxl, ui }) => {
|
||||
const { refinerSteps } = sdxl;
|
||||
const { shouldUseSliders } = ui;
|
||||
|
||||
return {
|
||||
refinerSteps,
|
||||
shouldUseSliders,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamSDXLRefinerSteps = () => {
|
||||
const { refinerSteps, shouldUseSliders } = useAppSelector(selector);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setRefinerSteps(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(setRefinerSteps(20));
|
||||
}, [dispatch]);
|
||||
|
||||
return shouldUseSliders ? (
|
||||
<IAISlider
|
||||
label={t('parameters.steps')}
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
onChange={handleChange}
|
||||
handleReset={handleReset}
|
||||
value={refinerSteps}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
sliderNumberInputProps={{ max: 500 }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
) : (
|
||||
<IAINumberInput
|
||||
label={t('parameters.steps')}
|
||||
min={1}
|
||||
max={500}
|
||||
step={1}
|
||||
onChange={handleChange}
|
||||
value={refinerSteps}
|
||||
numberInputFieldProps={{ textAlign: 'center' }}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamSDXLRefinerSteps);
|
@ -0,0 +1,28 @@
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setShouldUseSDXLRefiner } from 'features/sdxl/store/sdxlSlice';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
|
||||
|
||||
export default function ParamUseSDXLRefiner() {
|
||||
const shouldUseSDXLRefiner = useAppSelector(
|
||||
(state: RootState) => state.sdxl.shouldUseSDXLRefiner
|
||||
);
|
||||
const isRefinerAvailable = useIsRefinerAvailable();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleUseSDXLRefinerChange = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(setShouldUseSDXLRefiner(e.target.checked));
|
||||
};
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
label="Use Refiner"
|
||||
isChecked={shouldUseSDXLRefiner}
|
||||
onChange={handleUseSDXLRefinerChange}
|
||||
isDisabled={!isRefinerAvailable}
|
||||
/>
|
||||
);
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse';
|
||||
import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning';
|
||||
import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning';
|
||||
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
|
||||
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';
|
||||
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
|
||||
import ParamSDXLNegativeStyleConditioning from './ParamSDXLNegativeStyleConditioning';
|
||||
import ParamSDXLPositiveStyleConditioning from './ParamSDXLPositiveStyleConditioning';
|
||||
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
|
||||
|
||||
const SDXLTextToImageTabParameters = () => {
|
||||
return (
|
||||
<>
|
||||
<ParamPositiveConditioning />
|
||||
<ParamSDXLPositiveStyleConditioning />
|
||||
<ParamNegativeConditioning />
|
||||
<ParamSDXLNegativeStyleConditioning />
|
||||
<ProcessButtons />
|
||||
<TextToImageTabCoreParameters />
|
||||
<ParamSDXLRefinerCollapse />
|
||||
<ParamDynamicPromptsCollapse />
|
||||
<ParamNoiseCollapse />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default SDXLTextToImageTabParameters;
|
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
89
invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts
Normal file
@ -0,0 +1,89 @@
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import {
|
||||
MainModelParam,
|
||||
NegativeStylePromptSDXLParam,
|
||||
PositiveStylePromptSDXLParam,
|
||||
SchedulerParam,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { MainModelField } from 'services/api/types';
|
||||
|
||||
type SDXLInitialState = {
|
||||
positiveStylePrompt: PositiveStylePromptSDXLParam;
|
||||
negativeStylePrompt: NegativeStylePromptSDXLParam;
|
||||
shouldUseSDXLRefiner: boolean;
|
||||
sdxlImg2ImgDenoisingStrength: number;
|
||||
refinerModel: MainModelField | null;
|
||||
refinerSteps: number;
|
||||
refinerCFGScale: number;
|
||||
refinerScheduler: SchedulerParam;
|
||||
refinerAestheticScore: number;
|
||||
refinerStart: number;
|
||||
};
|
||||
|
||||
const sdxlInitialState: SDXLInitialState = {
|
||||
positiveStylePrompt: '',
|
||||
negativeStylePrompt: '',
|
||||
shouldUseSDXLRefiner: false,
|
||||
sdxlImg2ImgDenoisingStrength: 0.7,
|
||||
refinerModel: null,
|
||||
refinerSteps: 20,
|
||||
refinerCFGScale: 7.5,
|
||||
refinerScheduler: 'euler',
|
||||
refinerAestheticScore: 6,
|
||||
refinerStart: 0.7,
|
||||
};
|
||||
|
||||
const sdxlSlice = createSlice({
|
||||
name: 'sdxl',
|
||||
initialState: sdxlInitialState,
|
||||
reducers: {
|
||||
setPositiveStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||
state.positiveStylePrompt = action.payload;
|
||||
},
|
||||
setNegativeStylePromptSDXL: (state, action: PayloadAction<string>) => {
|
||||
state.negativeStylePrompt = action.payload;
|
||||
},
|
||||
setShouldUseSDXLRefiner: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldUseSDXLRefiner = action.payload;
|
||||
},
|
||||
setSDXLImg2ImgDenoisingStrength: (state, action: PayloadAction<number>) => {
|
||||
state.sdxlImg2ImgDenoisingStrength = action.payload;
|
||||
},
|
||||
refinerModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<MainModelParam | null>
|
||||
) => {
|
||||
state.refinerModel = action.payload;
|
||||
},
|
||||
setRefinerSteps: (state, action: PayloadAction<number>) => {
|
||||
state.refinerSteps = action.payload;
|
||||
},
|
||||
setRefinerCFGScale: (state, action: PayloadAction<number>) => {
|
||||
state.refinerCFGScale = action.payload;
|
||||
},
|
||||
setRefinerScheduler: (state, action: PayloadAction<SchedulerParam>) => {
|
||||
state.refinerScheduler = action.payload;
|
||||
},
|
||||
setRefinerAestheticScore: (state, action: PayloadAction<number>) => {
|
||||
state.refinerAestheticScore = action.payload;
|
||||
},
|
||||
setRefinerStart: (state, action: PayloadAction<number>) => {
|
||||
state.refinerStart = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
setPositiveStylePromptSDXL,
|
||||
setNegativeStylePromptSDXL,
|
||||
setShouldUseSDXLRefiner,
|
||||
setSDXLImg2ImgDenoisingStrength,
|
||||
refinerModelChanged,
|
||||
setRefinerSteps,
|
||||
setRefinerCFGScale,
|
||||
setRefinerScheduler,
|
||||
setRefinerAestheticScore,
|
||||
setRefinerStart,
|
||||
} = sdxlSlice.actions;
|
||||
|
||||
export default sdxlSlice.reducer;
|
@ -16,7 +16,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import { configSelector } from 'features/system/store/configSelectors';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@ -172,13 +172,22 @@ const InvokeTabs = () => {
|
||||
const { ref: galleryPanelRef, minSizePct: galleryMinSizePct } =
|
||||
useMinimumPanelSize(MIN_GALLERY_WIDTH, DEFAULT_GALLERY_PCT, 'app');
|
||||
|
||||
const handleTabChange = useCallback(
|
||||
(index: number) => {
|
||||
const activeTabName = tabMap[index];
|
||||
if (!activeTabName) {
|
||||
return;
|
||||
}
|
||||
dispatch(setActiveTab(activeTabName));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tabs
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
onChange={(index: number) => {
|
||||
dispatch(setActiveTab(index));
|
||||
}}
|
||||
onChange={handleTabChange}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
gap: 4,
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { Box, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import InitialImageDisplay from 'features/parameters/components/Parameters/ImageToImage/InitialImageDisplay';
|
||||
import SDXLImageToImageTabParameters from 'features/sdxl/components/SDXLImageToImageTabParameters';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import {
|
||||
ImperativePanelGroupHandle,
|
||||
@ -16,6 +18,7 @@ import ImageToImageTabParameters from './ImageToImageTabParameters';
|
||||
const ImageToImageTab = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||
|
||||
const handleDoubleClickHandle = useCallback(() => {
|
||||
if (!panelGroupRef.current) {
|
||||
@ -28,7 +31,11 @@ const ImageToImageTab = () => {
|
||||
return (
|
||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||
<ParametersPinnedWrapper>
|
||||
<ImageToImageTabParameters />
|
||||
{model && model.base_model === 'sdxl' ? (
|
||||
<SDXLImageToImageTabParameters />
|
||||
) : (
|
||||
<ImageToImageTabParameters />
|
||||
)}
|
||||
</ParametersPinnedWrapper>
|
||||
<Box sx={{ w: 'full', h: 'full' }}>
|
||||
<PanelGroup
|
||||
|
@ -16,6 +16,7 @@ import {
|
||||
useImportMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
import { setAdvancedAddScanModel } from '../../store/modelManagerSlice';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function FoundModelsList() {
|
||||
const searchFolder = useAppSelector(
|
||||
@ -24,7 +25,7 @@ export default function FoundModelsList() {
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
|
||||
// Get paths of models that are already installed
|
||||
const { data: installedModels } = useGetMainModelsQuery();
|
||||
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
// Get all model paths from a given directory
|
||||
const { foundModels, alreadyInstalled, filteredModels } =
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
@ -8,9 +7,11 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useMergeMainModelsMutation,
|
||||
@ -32,7 +33,7 @@ export default function MergeModelsPanel() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery();
|
||||
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||
|
||||
|
@ -8,10 +8,11 @@ import {
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { model } = useGetMainModelsQuery(undefined, {
|
||||
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
useGetMainModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
@ -26,13 +27,13 @@ const ModelList = (props: ModelListProps) => {
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<ModelFormat>('images');
|
||||
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||
}),
|
||||
|
@ -1,14 +1,22 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import TextToImageSDXLTabParameters from 'features/sdxl/components/SDXLTextToImageTabParameters';
|
||||
import { memo } from 'react';
|
||||
import ParametersPinnedWrapper from '../../ParametersPinnedWrapper';
|
||||
import TextToImageTabMain from './TextToImageTabMain';
|
||||
import TextToImageTabParameters from './TextToImageTabParameters';
|
||||
|
||||
const TextToImageTab = () => {
|
||||
const model = useAppSelector((state: RootState) => state.generation.model);
|
||||
return (
|
||||
<Flex sx={{ gap: 4, w: 'full', h: 'full' }}>
|
||||
<ParametersPinnedWrapper>
|
||||
<TextToImageTabParameters />
|
||||
{model && model.base_model === 'sdxl' ? (
|
||||
<TextToImageSDXLTabParameters />
|
||||
) : (
|
||||
<TextToImageTabParameters />
|
||||
)}
|
||||
</ParametersPinnedWrapper>
|
||||
<TextToImageTabMain />
|
||||
</Flex>
|
||||
|
@ -26,7 +26,7 @@ export const uiSlice = createSlice({
|
||||
name: 'ui',
|
||||
initialState: initialUIState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
|
||||
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||
setActiveTabReducer(state, action.payload);
|
||||
},
|
||||
setShouldPinParametersPanel: (state, action: PayloadAction<boolean>) => {
|
||||
|
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
16
invokeai/frontend/web/src/services/api/constants.ts
Normal file
@ -0,0 +1,16 @@
|
||||
import { BaseModelType } from './types';
|
||||
|
||||
export const ALL_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
'sdxl-refiner',
|
||||
];
|
||||
|
||||
export const NON_REFINER_BASE_MODELS: BaseModelType[] = [
|
||||
'sd-1',
|
||||
'sd-2',
|
||||
'sdxl',
|
||||
];
|
||||
|
||||
export const REFINER_BASE_MODELS: BaseModelType[] = ['sdxl-refiner'];
|
@ -144,8 +144,19 @@ const createModelEntities = <T extends AnyModelConfigEntity>(
|
||||
|
||||
export const modelsApi = api.injectEndpoints({
|
||||
endpoints: (build) => ({
|
||||
getMainModels: build.query<EntityState<MainModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
|
||||
getMainModels: build.query<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
BaseModelType[]
|
||||
>({
|
||||
query: (base_models) => {
|
||||
const params = {
|
||||
model_type: 'main',
|
||||
base_models,
|
||||
};
|
||||
|
||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||
return `models/?${query}`;
|
||||
},
|
||||
providesTags: (result, error, arg) => {
|
||||
const tags: ApiFullTagDescription[] = [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
@ -187,7 +198,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
importMainModels: build.mutation<
|
||||
ImportMainModelResponse,
|
||||
@ -200,7 +214,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
addMainModels: build.mutation<AddMainModelResponse, AddMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
@ -210,7 +227,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
deleteMainModels: build.mutation<
|
||||
DeleteMainModelResponse,
|
||||
@ -222,7 +242,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
convertMainModels: build.mutation<
|
||||
ConvertMainModelResponse,
|
||||
@ -235,7 +258,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
params: params,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
query: ({ base_model, body }) => {
|
||||
@ -245,7 +271,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
syncModels: build.mutation<SyncModelsResponse, void>({
|
||||
query: () => {
|
||||
@ -254,7 +283,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
method: 'POST',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
invalidatesTags: [
|
||||
{ type: 'MainModel', id: LIST_TAG },
|
||||
{ type: 'SDXLRefinerModel', id: LIST_TAG },
|
||||
],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
export const useIsRefinerAvailable = () => {
|
||||
const { isRefinerAvailable } = useGetMainModelsQuery(REFINER_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
isRefinerAvailable: data ? data.ids.length > 0 : false,
|
||||
}),
|
||||
});
|
||||
|
||||
return isRefinerAvailable;
|
||||
};
|
@ -1014,6 +1014,11 @@ export type components = {
|
||||
* @description The LoRAs used for inference
|
||||
*/
|
||||
loras: (components["schemas"]["LoRAMetadataField"])[];
|
||||
/**
|
||||
* Vae
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
/**
|
||||
* Strength
|
||||
* @description The strength used for latents-to-latents
|
||||
@ -1025,10 +1030,45 @@ export type components = {
|
||||
*/
|
||||
init_image?: string;
|
||||
/**
|
||||
* Vae
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
* Positive Style Prompt
|
||||
* @description The positive style prompt parameter
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
positive_style_prompt?: string;
|
||||
/**
|
||||
* Negative Style Prompt
|
||||
* @description The negative style prompt parameter
|
||||
*/
|
||||
negative_style_prompt?: string;
|
||||
/**
|
||||
* Refiner Model
|
||||
* @description The SDXL Refiner model used
|
||||
*/
|
||||
refiner_model?: components["schemas"]["MainModelField"];
|
||||
/**
|
||||
* Refiner Cfg Scale
|
||||
* @description The classifier-free guidance scale parameter used for the refiner
|
||||
*/
|
||||
refiner_cfg_scale?: number;
|
||||
/**
|
||||
* Refiner Steps
|
||||
* @description The number of steps used for the refiner
|
||||
*/
|
||||
refiner_steps?: number;
|
||||
/**
|
||||
* Refiner Scheduler
|
||||
* @description The scheduler used for the refiner
|
||||
*/
|
||||
refiner_scheduler?: string;
|
||||
/**
|
||||
* Refiner Aesthetic Store
|
||||
* @description The aesthetic score used for the refiner
|
||||
*/
|
||||
refiner_aesthetic_store?: number;
|
||||
/**
|
||||
* Refiner Start
|
||||
* @description The start value used for refiner denoising
|
||||
*/
|
||||
refiner_start?: number;
|
||||
};
|
||||
/**
|
||||
* CvInpaintInvocation
|
||||
@ -3268,6 +3308,46 @@ export type components = {
|
||||
* @description The VAE used for decoding, if the main model's default was not used
|
||||
*/
|
||||
vae?: components["schemas"]["VAEModelField"];
|
||||
/**
|
||||
* Positive Style Prompt
|
||||
* @description The positive style prompt parameter
|
||||
*/
|
||||
positive_style_prompt?: string;
|
||||
/**
|
||||
* Negative Style Prompt
|
||||
* @description The negative style prompt parameter
|
||||
*/
|
||||
negative_style_prompt?: string;
|
||||
/**
|
||||
* Refiner Model
|
||||
* @description The SDXL Refiner model used
|
||||
*/
|
||||
refiner_model?: components["schemas"]["MainModelField"];
|
||||
/**
|
||||
* Refiner Cfg Scale
|
||||
* @description The classifier-free guidance scale parameter used for the refiner
|
||||
*/
|
||||
refiner_cfg_scale?: number;
|
||||
/**
|
||||
* Refiner Steps
|
||||
* @description The number of steps used for the refiner
|
||||
*/
|
||||
refiner_steps?: number;
|
||||
/**
|
||||
* Refiner Scheduler
|
||||
* @description The scheduler used for the refiner
|
||||
*/
|
||||
refiner_scheduler?: string;
|
||||
/**
|
||||
* Refiner Aesthetic Store
|
||||
* @description The aesthetic score used for the refiner
|
||||
*/
|
||||
refiner_aesthetic_store?: number;
|
||||
/**
|
||||
* Refiner Start
|
||||
* @description The start value used for refiner denoising
|
||||
*/
|
||||
refiner_start?: number;
|
||||
};
|
||||
/**
|
||||
* MetadataAccumulatorOutput
|
||||
@ -5355,6 +5435,12 @@ export type components = {
|
||||
*/
|
||||
image?: components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
@ -5367,12 +5453,6 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
|
Loading…
Reference in New Issue
Block a user