feat: Add SDXL Refiner to Linear UI

This commit is contained in:
blessedcoolant 2023-07-25 23:13:37 +12:00 committed by psychedelicious
parent 5202610160
commit 6295e56d96
4 changed files with 172 additions and 13 deletions

View File

@ -7,6 +7,7 @@ import {
ImageToLatentsInvocation,
} from 'services/api/types';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './buildSDXLRefinerGraph';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
@ -44,7 +45,12 @@ export const buildLinearSDXLImageToImageGraph = (
shouldUseNoiseSettings,
} = state.generation;
const { positiveStylePrompt, negativeStylePrompt } = state.sdxl;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
// TODO: add batch functionality
// const {
@ -115,7 +121,7 @@ export const buildLinearSDXLImageToImageGraph = (
cfg_scale,
scheduler,
steps,
denoising_start: 1 - strength,
denoising_start: shouldUseSDXLRefiner ? refinerStart : 1 - strength,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
@ -389,6 +395,11 @@ export const buildLinearSDXLImageToImageGraph = (
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, LATENTS_TO_IMAGE);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@ -3,6 +3,7 @@ import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addSDXLRefinerToGraph } from './buildSDXLRefinerGraph';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
@ -32,7 +33,12 @@ export const buildLinearSDXLTextToImageGraph = (
shouldUseNoiseSettings,
} = state.generation;
const { positiveStylePrompt, negativeStylePrompt } = state.sdxl;
const {
positiveStylePrompt,
negativeStylePrompt,
shouldUseSDXLRefiner,
refinerStart,
} = state.sdxl;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
@ -86,6 +92,7 @@ export const buildLinearSDXLTextToImageGraph = (
cfg_scale,
scheduler,
steps,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
@ -228,6 +235,11 @@ export const buildLinearSDXLTextToImageGraph = (
},
});
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_TEXT_TO_LATENTS);
}
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@ -1,32 +1,163 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { METADATA_ACCUMULATOR, SDXL_TEXT_TO_LATENTS } from './constants';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
SDXL_MODEL_LOADER,
SDXL_REFINER_LATENTS_TO_LATENTS,
SDXL_REFINER_MODEL_LOADER,
SDXL_REFINER_NEGATIVE_CONDITIONING,
SDXL_REFINER_POSITIVE_CONDITIONING,
} from './constants';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { shouldUseSDXLRefiner, model } = state.generation;
const { positivePrompt, negativePrompt } = state.generation;
const {
refinerModel,
refinerAestheticScore,
positiveStylePrompt,
negativeStylePrompt,
refinerSteps,
refinerScheduler,
refinerCFGScale,
refinerStart,
} = state.sdxl;
if (!refinerModel) return;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!shouldUseSDXLRefiner) return;
// Unplug SDXL Latents Generation To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(e.source.node_id === baseNodeId && ['latents'].includes(e.source.field))
);
// Unplug SDXL Text To Latents To Latents To Image
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_TEXT_TO_LATENTS &&
['latents'].includes(e.source.field)
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
);
// graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
// id: SDXL_REFINER_MODEL_LOADER,
// type: 'sdxl_refiner_model_loader',
// };
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
type: 'sdxl_refiner_model_loader',
id: SDXL_REFINER_MODEL_LOADER,
model: refinerModel,
};
graph.nodes[SDXL_REFINER_POSITIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_POSITIVE_CONDITIONING,
style: `${positivePrompt} ${positiveStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_NEGATIVE_CONDITIONING] = {
type: 'sdxl_refiner_compel_prompt',
id: SDXL_REFINER_NEGATIVE_CONDITIONING,
style: `${negativePrompt} ${negativeStylePrompt}`,
aesthetic_score: refinerAestheticScore,
};
graph.nodes[SDXL_REFINER_LATENTS_TO_LATENTS] = {
type: 'l2l_sdxl',
id: SDXL_REFINER_LATENTS_TO_LATENTS,
cfg_scale: refinerCFGScale,
steps: refinerSteps / (1 - refinerStart),
scheduler: refinerScheduler,
denoising_start: refinerStart,
};
// graph.nodes[LATENTS_TO_IMAGE] = {
// type: 'l2i',
// id: LATENTS_TO_IMAGE,
// };
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'unet',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'clip2',
},
destination: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'clip2',
},
},
{
source: {
node_id: SDXL_REFINER_POSITIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
},
{
source: {
node_id: SDXL_REFINER_NEGATIVE_CONDITIONING,
field: 'conditioning',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
},
{
source: {
node_id: baseNodeId,
field: 'latents',
},
destination: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
},
{
source: {
node_id: SDXL_REFINER_LATENTS_TO_LATENTS,
field: 'latents',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
}
);
};

View File

@ -27,6 +27,11 @@ export const SDXL_MODEL_LOADER = 'sdxl_model_loader';
export const SDXL_TEXT_TO_LATENTS = 't2l_sdxl';
export const SDXL_LATENTS_TO_LATENTS = 'l2l_sdxl';
export const SDXL_REFINER_MODEL_LOADER = 'sdxl_refiner_model_loader';
export const SDXL_REFINER_POSITIVE_CONDITIONING =
'sdxl_refiner_positive_conditioning';
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_LATENTS_TO_LATENTS = 'l2l_sdxl_refiner';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';