feat: Add Custom VAE Support to Linear UI

This commit is contained in:
blessedcoolant 2023-07-01 12:10:35 +12:00 committed by psychedelicious
parent 7e18814dd0
commit 511978979e
7 changed files with 91 additions and 72 deletions

View File

@ -0,0 +1,68 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { modelIdToVAEModelField } from '../modelIdToVAEModelField';
import {
IMAGE_TO_LATENTS,
INPAINT,
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
VAE_LOADER,
} from './constants';
export const addVAEToGraph = (
graph: NonNullableGraph,
state: RootState
): void => {
const { vae: vaeId } = state.generation;
const vae_model = modelIdToVAEModelField(vaeId);
if (vaeId !== 'auto') {
graph.nodes[VAE_LOADER] = {
type: 'vae_loader',
id: VAE_LOADER,
vae_model,
};
}
if (
graph.id === 'text_to_image_graph' ||
graph.id === 'image_to_image_graph'
) {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
});
}
if (graph.id === 'image_to_image_graph') {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
});
}
if (graph.id === 'inpaint_graph') {
graph.edges.push({
source: {
node_id: vaeId === 'auto' ? MAIN_MODEL_LOADER : VAE_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
});
}
};

View File

@ -9,6 +9,7 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
@ -122,16 +123,6 @@ export const buildCanvasImageToImageGraph = (
field: 'clip',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -162,16 +153,6 @@ export const buildCanvasImageToImageGraph = (
field: 'noise',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
@ -271,6 +252,9 @@ export const buildCanvasImageToImageGraph = (
});
}
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -8,6 +8,7 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addVAEToGraph } from './addVAEToGraph';
import {
INPAINT,
INPAINT_GRAPH,
@ -170,16 +171,6 @@ export const buildCanvasInpaintGraph = (
field: 'unet',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: INPAINT,
field: 'vae',
},
},
{
source: {
node_id: RANGE_OF_SIZE,
@ -203,6 +194,9 @@ export const buildCanvasInpaintGraph = (
],
};
// Add VAE
addVAEToGraph(graph, state);
// handle seed
if (shouldRandomizeSeed) {
// Random int node to generate the starting seed

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
@ -143,16 +144,6 @@ export const buildCanvasTextToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: NOISE,
@ -166,6 +157,9 @@ export const buildCanvasTextToImageGraph = (
],
};
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -10,7 +10,10 @@ import {
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
IMAGE_TO_IMAGE_GRAPH,
IMAGE_TO_LATENTS,
LATENTS_TO_IMAGE,
@ -20,8 +23,6 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
IMAGE_COLLECTION,
IMAGE_COLLECTION_ITERATE,
} from './constants';
const moduleLog = log.child({ namespace: 'nodes' });
@ -136,16 +137,6 @@ export const buildLinearImageToImageGraph = (
field: 'clip',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: LATENTS_TO_LATENTS,
@ -176,16 +167,7 @@ export const buildLinearImageToImageGraph = (
field: 'noise',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: IMAGE_TO_LATENTS,
field: 'vae',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
@ -322,6 +304,8 @@ export const buildLinearImageToImageGraph = (
},
});
}
// Add VAE
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph';
import { modelIdToMainModelField } from '../modelIdToMainModelField';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addVAEToGraph } from './addVAEToGraph';
import {
LATENTS_TO_IMAGE,
MAIN_MODEL_LOADER,
@ -136,16 +137,6 @@ export const buildLinearTextToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: MAIN_MODEL_LOADER,
field: 'vae',
},
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'vae',
},
},
{
source: {
node_id: NOISE,
@ -159,6 +150,9 @@ export const buildLinearTextToImageGraph = (
],
};
// Add Custom VAE Support
addVAEToGraph(graph, state);
// add dynamic prompts, mutating `graph`
addDynamicPromptsToGraph(graph, state);

View File

@ -8,6 +8,7 @@ export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';
export const ITERATE = 'iterate';
export const MAIN_MODEL_LOADER = 'main_model_loader';
export const VAE_LOADER = 'vae_loader';
export const IMAGE_TO_LATENTS = 'image_to_latents';
export const LATENTS_TO_LATENTS = 'latents_to_latents';
export const RESIZE = 'resize_image';