feat(ui): SDXL clip skip

Uses the same CLIP Skip value for both CLIP1 and CLIP2.

Adjusted SDXL CLIP Skip min/max/markers to be within the valid range (0 to 11).

Closes #4583
This commit is contained in:
psychedelicious 2024-05-16 15:55:22 +10:00 committed by Kent Keirsey
parent 3b1743b7c2
commit 40b4fa7238
4 changed files with 26 additions and 13 deletions

View File

@ -11,6 +11,8 @@ export const addSDXLLoRas = (
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
modelLoader: Invocation<'sdxl_model_loader'>, modelLoader: Invocation<'sdxl_model_loader'>,
seamless: Invocation<'seamless'> | null, seamless: Invocation<'seamless'> | null,
clipSkip: Invocation<'clip_skip'>,
clipSkip2: Invocation<'clip_skip'>,
posCond: Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'sdxl_compel_prompt'> negCond: Invocation<'sdxl_compel_prompt'>
): void => { ): void => {
@ -37,8 +39,8 @@ export const addSDXLLoRas = (
g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
// Use seamless as UNet input if it exists, otherwise use the model loader // Use seamless as UNet input if it exists, otherwise use the model loader
g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet'); g.addEdge(seamless ?? modelLoader, 'unet', loraCollectionLoader, 'unet');
g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip'); g.addEdge(clipSkip, 'clip', loraCollectionLoader, 'clip');
g.addEdge(modelLoader, 'clip2', loraCollectionLoader, 'clip2'); g.addEdge(clipSkip2, 'clip', loraCollectionLoader, 'clip2');
// Reroute UNet & CLIP connections through the LoRA collection loader // Reroute UNet & CLIP connections through the LoRA collection loader
g.deleteEdgesTo(denoise, ['unet']); g.deleteEdgesTo(denoise, ['unet']);
g.deleteEdgesTo(posCond, ['clip', 'clip2']); g.deleteEdgesTo(posCond, ['clip', 'clip2']);

View File

@ -1,6 +1,7 @@
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { import {
CLIP_SKIP,
LATENTS_TO_IMAGE, LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING, NEGATIVE_CONDITIONING,
NEGATIVE_CONDITIONING_COLLECT, NEGATIVE_CONDITIONING_COLLECT,
@ -29,6 +30,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
model, model,
cfgScale: cfg_scale, cfgScale: cfg_scale,
cfgRescaleMultiplier: cfg_rescale_multiplier, cfgRescaleMultiplier: cfg_rescale_multiplier,
clipSkip: skipped_layers,
scheduler, scheduler,
seed, seed,
steps, steps,
@ -51,6 +53,16 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
id: SDXL_MODEL_LOADER, id: SDXL_MODEL_LOADER,
model, model,
}); });
const clipSkip = g.addNode({
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers,
});
const clipSkip2 = g.addNode({
type: 'clip_skip',
id: `${CLIP_SKIP}_2`,
skipped_layers,
});
const posCond = g.addNode({ const posCond = g.addNode({
type: 'sdxl_compel_prompt', type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING, id: POSITIVE_CONDITIONING,
@ -103,10 +115,12 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i; let imageOutput: Invocation<'l2i'> | Invocation<'img_nsfw'> | Invocation<'img_watermark'> = l2i;
g.addEdge(modelLoader, 'unet', denoise, 'unet'); g.addEdge(modelLoader, 'unet', denoise, 'unet');
g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 'clip', clipSkip, 'clip');
g.addEdge(modelLoader, 'clip', negCond, 'clip'); g.addEdge(modelLoader, 'clip2', clipSkip2, 'clip');
g.addEdge(modelLoader, 'clip2', posCond, 'clip2'); g.addEdge(clipSkip, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 'clip2', negCond, 'clip2'); g.addEdge(clipSkip, 'clip', negCond, 'clip');
g.addEdge(clipSkip2, 'clip', posCond, 'clip2');
g.addEdge(clipSkip2, 'clip', negCond, 'clip2');
g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
g.addEdge(negCond, 'conditioning', negCondCollect, 'item'); g.addEdge(negCond, 'conditioning', negCondCollect, 'item');
g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning'); g.addEdge(posCondCollect, 'collection', denoise, 'positive_conditioning');
@ -132,12 +146,13 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
scheduler, scheduler,
positive_style_prompt: positiveStylePrompt, positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt, negative_style_prompt: negativeStylePrompt,
clip_skip: skipped_layers,
vae: vae ?? undefined, vae: vae ?? undefined,
}); });
const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader); const seamless = addSeamless(state, g, denoise, modelLoader, vaeLoader);
addSDXLLoRas(state, g, denoise, modelLoader, seamless, posCond, negCond); addSDXLLoRas(state, g, denoise, modelLoader, seamless, clipSkip, clipSkip2, posCond, negCond);
// We might get the VAE from the main model, custom VAE, or seamless node. // We might get the VAE from the main model, custom VAE, or seamless node.
const vaeSource = seamless ?? vaeLoader ?? modelLoader; const vaeSource = seamless ?? vaeLoader ?? modelLoader;

View File

@ -39,10 +39,6 @@ const ParamClipSkip = () => {
return CLIP_SKIP_MAP[model.base].markers; return CLIP_SKIP_MAP[model.base].markers;
}, [model]); }, [model]);
if (model?.base === 'sdxl') {
return null;
}
return ( return (
<FormControl> <FormControl>
<InformationalPopover feature="clipSkip"> <InformationalPopover feature="clipSkip">

View File

@ -39,8 +39,8 @@ export const CLIP_SKIP_MAP = {
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
}, },
sdxl: { sdxl: {
maxClip: 24, maxClip: 11,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24], markers: [0, 1, 2, 5, 11],
}, },
'sdxl-refiner': { 'sdxl-refiner': {
maxClip: 24, maxClip: 24,