feat: Add Seamless to T2I / I2I / SDXL T2I / I2I + Refiner

This commit is contained in:
blessedcoolant 2023-08-29 04:01:04 +12:00
parent bb085c5fba
commit 594e547c3b
12 changed files with 290 additions and 33 deletions

View File

@ -1,11 +1,15 @@
import { RootState } from 'app/store/store';
import { MetadataAccumulatorInvocation } from 'services/api/types';
import {
MetadataAccumulatorInvocation,
SeamlessModeInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
MASK_BLUR,
METADATA_ACCUMULATOR,
REFINER_SEAMLESS,
SDXL_CANVAS_IMAGE_TO_IMAGE_GRAPH,
SDXL_CANVAS_INPAINT_GRAPH,
SDXL_CANVAS_OUTPAINT_GRAPH,
@ -21,7 +25,8 @@ import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
export const addSDXLRefinerToGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
baseNodeId: string,
modelLoaderNodeId?: string
): void => {
const {
refinerModel,
@ -33,6 +38,8 @@ export const addSDXLRefinerToGraph = (
refinerStart,
} = state.sdxl;
const { seamlessXAxis, seamlessYAxis } = state.generation;
if (!refinerModel) {
return;
}
@ -53,6 +60,10 @@ export const addSDXLRefinerToGraph = (
metadataAccumulator.refiner_steps = refinerSteps;
}
const modelLoaderId = modelLoaderNodeId
? modelLoaderNodeId
: SDXL_MODEL_LOADER;
// Construct Style Prompt
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
craftSDXLStylePrompt(state, true);
@ -65,10 +76,7 @@ export const addSDXLRefinerToGraph = (
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === SDXL_MODEL_LOADER &&
['vae'].includes(e.source.field)
)
!(e.source.node_id === modelLoaderId && ['vae'].includes(e.source.field))
);
graph.nodes[SDXL_REFINER_MODEL_LOADER] = {
@ -98,8 +106,39 @@ export const addSDXLRefinerToGraph = (
denoising_end: 1,
};
graph.edges.push(
{
// Add Seamless To Refiner
if (seamlessXAxis || seamlessYAxis) {
graph.nodes[REFINER_SEAMLESS] = {
id: REFINER_SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
},
destination: {
node_id: REFINER_SEAMLESS,
field: 'unet',
},
},
{
source: {
node_id: REFINER_SEAMLESS,
field: 'unet',
},
destination: {
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'unet',
},
}
);
} else {
graph.edges.push({
source: {
node_id: SDXL_REFINER_MODEL_LOADER,
field: 'unet',
@ -108,7 +147,10 @@ export const addSDXLRefinerToGraph = (
node_id: SDXL_REFINER_DENOISE_LATENTS,
field: 'unet',
},
},
});
}
graph.edges.push(
{
source: {
node_id: SDXL_REFINER_MODEL_LOADER,

View File

@ -0,0 +1,79 @@
import { RootState } from 'app/store/store';
import { SeamlessModeInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import {
DENOISE_LATENTS,
IMAGE_TO_IMAGE_GRAPH,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
export const addSeamlessToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string
): void => {
// Remove Existing UNet Connections
const { seamlessXAxis, seamlessYAxis } = state.generation;
graph.nodes[SEAMLESS] = {
id: SEAMLESS,
type: 'seamless',
seamless_x: seamlessXAxis,
seamless_y: seamlessYAxis,
} as SeamlessModeInvocation;
graph.edges = graph.edges.filter(
(e) =>
!(
e.source.node_id === modelLoaderNodeId &&
['unet'].includes(e.source.field)
) &&
!(
e.source.node_id === modelLoaderNodeId &&
['vae'].includes(e.source.field)
)
);
if (
graph.id === TEXT_TO_IMAGE_GRAPH ||
graph.id === IMAGE_TO_IMAGE_GRAPH ||
graph.id === SDXL_TEXT_TO_IMAGE_GRAPH ||
graph.id === SDXL_IMAGE_TO_IMAGE_GRAPH
) {
graph.edges.push(
{
source: {
node_id: modelLoaderNodeId,
field: 'unet',
},
destination: {
node_id: SEAMLESS,
field: 'unet',
},
},
{
source: {
node_id: modelLoaderNodeId,
field: 'vae',
},
destination: {
node_id: SEAMLESS,
field: 'vae',
},
},
{
source: {
node_id: SEAMLESS,
field: 'unet',
},
destination: {
node_id: DENOISE_LATENTS,
field: 'unet',
},
}
);
}
};

View File

@ -10,6 +10,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -24,6 +25,7 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
SEAMLESS,
} from './constants';
/**
@ -49,6 +51,8 @@ export const buildLinearImageToImageGraph = (
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation;
// TODO: add batch functionality
@ -80,6 +84,8 @@ export const buildLinearImageToImageGraph = (
throw new Error('No model found in state');
}
let modelLoaderNodeId = MAIN_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
@ -338,11 +344,17 @@ export const buildLinearImageToImageGraph = (
},
});
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// optionally add custom VAE
addVAEToGraph(state, graph, MAIN_MODEL_LOADER);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);

View File

@ -11,6 +11,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -20,10 +21,12 @@ import {
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
RESIZE,
SDXL_DENOISE_LATENTS,
SDXL_IMAGE_TO_IMAGE_GRAPH,
SDXL_MODEL_LOADER,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -49,6 +52,8 @@ export const buildLinearSDXLImageToImageGraph = (
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const {
@ -79,6 +84,9 @@ export const buildLinearSDXLImageToImageGraph = (
throw new Error('No model found in state');
}
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
@ -351,15 +359,23 @@ export const buildLinearSDXLImageToImageGraph = (
},
});
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER);
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
}
// optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
addVAEToGraph(state, graph, modelLoaderNodeId);
// Add LoRA Support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -7,6 +7,7 @@ import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -15,9 +16,11 @@ import {
NEGATIVE_CONDITIONING,
NOISE,
POSITIVE_CONDITIONING,
REFINER_SEAMLESS,
SDXL_DENOISE_LATENTS,
SDXL_MODEL_LOADER,
SDXL_TEXT_TO_IMAGE_GRAPH,
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
@ -38,6 +41,8 @@ export const buildLinearSDXLTextToImageGraph = (
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const {
@ -61,6 +66,9 @@ export const buildLinearSDXLTextToImageGraph = (
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
craftSDXLStylePrompt(state, shouldConcatSDXLStylePrompt);
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
@ -244,16 +252,23 @@ export const buildLinearSDXLTextToImageGraph = (
},
});
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// Add Refiner if enabled
if (shouldUseSDXLRefiner) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
modelLoaderNodeId = REFINER_SEAMLESS;
}
// optionally add custom VAE
addVAEToGraph(state, graph, SDXL_MODEL_LOADER);
addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, SDXL_MODEL_LOADER);
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);

View File

@ -10,6 +10,7 @@ import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
@ -22,6 +23,7 @@ import {
NOISE,
ONNX_MODEL_LOADER,
POSITIVE_CONDITIONING,
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
@ -42,6 +44,8 @@ export const buildLinearTextToImageGraph = (
shouldUseCpuNoise,
shouldUseNoiseSettings,
vaePrecision,
seamlessXAxis,
seamlessYAxis,
} = state.generation;
const use_cpu = shouldUseNoiseSettings
@ -55,7 +59,7 @@ export const buildLinearTextToImageGraph = (
const isUsingOnnxModel = model.model_type === 'onnx';
const modelLoaderNodeId = isUsingOnnxModel
let modelLoaderNodeId = isUsingOnnxModel
? ONNX_MODEL_LOADER
: MAIN_MODEL_LOADER;
@ -258,6 +262,12 @@ export const buildLinearTextToImageGraph = (
},
});
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
modelLoaderNodeId = SEAMLESS;
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);

View File

@ -56,6 +56,8 @@ export const SDXL_REFINER_POSITIVE_CONDITIONING =
export const SDXL_REFINER_NEGATIVE_CONDITIONING =
'sdxl_refiner_negative_conditioning';
export const SDXL_REFINER_DENOISE_LATENTS = 'sdxl_refiner_denoise_latents';
export const SEAMLESS = 'seamless';
export const REFINER_SEAMLESS = 'refiner_seamless';
// friendly graph ids
export const TEXT_TO_IMAGE_GRAPH = 'text_to_image_graph';

View File

@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import { memo } from 'react';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
import ParamSDXLRefinerCollapse from './ParamSDXLRefinerCollapse';
@ -17,6 +18,7 @@ const SDXLImageToImageTabParameters = () => {
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
<ParamSeamlessCollapse />
</>
);
};

View File

@ -2,6 +2,7 @@ import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/Para
import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse';
import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse';
import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse';
import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse';
import TextToImageTabCoreParameters from 'features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters';
import { memo } from 'react';
import ParamSDXLPromptArea from './ParamSDXLPromptArea';
@ -17,6 +18,7 @@ const SDXLTextToImageTabParameters = () => {
<ParamLoraCollapse />
<ParamDynamicPromptsCollapse />
<ParamNoiseCollapse />
<ParamSeamlessCollapse />
</>
);
};

View File

@ -9,7 +9,6 @@ export const initialConfigState: AppConfig = {
disabledFeatures: ['lightbox', 'faceRestore', 'batches'],
disabledSDFeatures: [
'variation',
'seamless',
'symmetry',
'hires',
'perlinNoise',

File diff suppressed because one or more lines are too long

View File

@ -130,6 +130,7 @@ export type ESRGANInvocation = s['ESRGANInvocation'];
export type DivideInvocation = s['DivideInvocation'];
export type ImageNSFWBlurInvocation = s['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = s['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = s['SeamlessModeInvocation'];
// ControlNet Nodes
export type ControlNetInvocation = s['ControlNetInvocation'];