cleanup, add loras

This commit is contained in:
Mary Hipp 2024-07-18 15:28:02 -04:00 committed by psychedelicious
parent ea449f5a0a
commit 7668dc68a0
4 changed files with 150 additions and 56 deletions

View File

@ -1,11 +1,12 @@
import { Graph, GraphType } from 'features/nodes/util/graph/generation/Graph'; import { Graph, GraphType } from 'features/nodes/util/graph/generation/Graph';
import { RootState } from '../../../../app/store/store'; import { RootState } from '../../../../app/store/store';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { ControlNetModelConfig, Invocation, NonNullableGraph } from '../../../../services/api/types'; import { CLIP_SKIP, CONTROL_NET_COLLECT, ESRGAN, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, RESIZE, SDXL_MODEL_LOADER, TILED_MULTI_DIFFUSION_DENOISE_LATENTS, UNSHARP_MASK, VAE_LOADER } from './constants';
import { ESRGAN, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { isParamESRGANModelName } from '../../../parameters/store/postprocessingSlice'; import { isParamESRGANModelName } from '../../../parameters/store/postprocessingSlice';
import { ControlNetConfig } from '../../../controlAdapters/store/types'; import { getSDXLStylePrompts } from './graphBuilderUtils';
import { MODEL_TYPES } from '../../types/constants'; import { addLoRAs } from './generation/addLoRAs';
import { addSDXLLoRas } from './generation/addSDXLLoRAs';
import { modelsApi } from '../../../../services/api/endpoints/models';
export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => { export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => {
@ -24,63 +25,75 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
assert(model, 'No model found in state'); assert(model, 'No model found in state');
assert(upscaleModel, 'No upscale model found in state'); assert(upscaleModel, 'No upscale model found in state');
assert(upscaleInitialImage, 'No initial image found in state'); assert(upscaleInitialImage, 'No initial image found in state');
assert(isParamESRGANModelName(upscaleModel.name), "")
if (!isParamESRGANModelName(upscaleModel.name)) {
throw new Error()
}
const g = new Graph() const g = new Graph()
const unsharp_mask_1 = g.addNode({ const unsharpMaskNode1 = g.addNode({
id: 'unsharp_mask_1', id: `${UNSHARP_MASK}_1`,
type: 'unsharp_mask', type: 'unsharp_mask',
image: upscaleInitialImage, image: upscaleInitialImage,
radius: 2, radius: 2,
strength: ((sharpness + 10) * 3.75) + 25 strength: ((sharpness + 10) * 3.75) + 25
}) })
const esrgan = g.addNode({ const upscaleNode = g.addNode({
id: ESRGAN, id: ESRGAN,
type: 'esrgan', type: 'esrgan',
model_name: upscaleModel.name, model_name: upscaleModel.name,
tile_size: 500 tile_size: 500
}) })
g.addEdge(unsharp_mask_1, 'image', esrgan, 'image') g.addEdge(unsharpMaskNode1, 'image', upscaleNode, 'image')
const unsharp_mask_2 = g.addNode({ const unsharpMaskNode2 = g.addNode({
id: 'unsharp_mask_2', id: `${UNSHARP_MASK}_2`,
type: 'unsharp_mask', type: 'unsharp_mask',
radius: 2, radius: 2,
strength: 50 strength: 50
}) })
g.addEdge(esrgan, 'image', unsharp_mask_2, 'image',) g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image',)
const SCALE = 2 const SCALE = 4
const resizeNode = g.addNode({ const resizeNode = g.addNode({
id: 'img_resize', id: RESIZE,
type: 'img_resize', type: 'img_resize',
width: upscaleInitialImage.width * SCALE, // TODO: handle floats width: upscaleInitialImage.width * SCALE, // TODO: handle floats
height: upscaleInitialImage.height * SCALE, // TODO: handle floats height: upscaleInitialImage.height * SCALE, // TODO: handle floats
resample_mode: "lanczos" resample_mode: "lanczos",
is_intermediate: false
}) })
g.addEdge(unsharp_mask_2, 'image', resizeNode, "image") g.addEdge(unsharpMaskNode2, 'image', resizeNode, "image")
const noiseNode = g.addNode({
id: NOISE,
type: "noise",
seed,
})
g.addEdge(resizeNode, 'width', noiseNode, "width")
g.addEdge(resizeNode, 'height', noiseNode, "height")
const sharpnessNode: Invocation<'unsharp_mask'> = { //before and after esrgan const i2lNode = g.addNode({
id: 'unsharp_mask', id: IMAGE_TO_LATENTS,
type: 'unsharp_mask', type: "i2l",
image: upscaleInitialImage, is_intermediate: false,
radius: 2, fp32: vaePrecision === "fp32"
strength: ((sharpness + 10) * 3.75) + 25 })
};
const creativityNode: Invocation<'tiled_multi_diffusion_denoise_latents'> = { //before and after esrgan g.addEdge(resizeNode, 'image', i2lNode, "image")
id: 'tiled_multi_diffusion_denoise_latents',
const l2iNode = g.addNode({
type: "l2i",
id: LATENTS_TO_IMAGE,
fp32: vaePrecision === "fp32"
})
const tiledMultidiffusionNode = g.addNode({
id: TILED_MULTI_DIFFUSION_DENOISE_LATENTS,
type: 'tiled_multi_diffusion_denoise_latents', type: 'tiled_multi_diffusion_denoise_latents',
tile_height: 1024, tile_height: 1024,
tile_width: 1024, tile_width: 1024,
@ -90,45 +103,124 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
scheduler, scheduler,
denoising_start: (((creativity * -1) + 10) * 4.99) / 100, denoising_start: (((creativity * -1) + 10) * 4.99) / 100,
denoising_end: 1 denoising_end: 1
}; });
const controlnetModel = { const clipSkipNode = g.addNode({
key: "placeholder", type: 'clip_skip',
hash: "placeholder", id: CLIP_SKIP,
});
let posCondNode, negCondNode, modelNode;
if (model.base === "sdxl") {
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
posCondNode = g.addNode({
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: positiveStylePrompt
});
negCondNode = g.addNode({
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: negativeStylePrompt
});
modelNode = g.addNode({
type: 'sdxl_model_loader',
id: SDXL_MODEL_LOADER,
model,
});
addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);
} else {
posCondNode = g.addNode({
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
});
negCondNode = g.addNode({
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
});
modelNode = g.addNode({
type: 'main_model_loader',
id: MAIN_MODEL_LOADER,
model,
});
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
}
g.addEdge(modelNode, 'clip', clipSkipNode, 'clip');
g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip');
g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip');
let vaeNode;
if (vae) {
vaeNode = g.addNode({
id: VAE_LOADER,
type: "vae_loader",
vae_model: vae
})
}
g.addEdge(vaeNode || modelNode, "vae", i2lNode, "vae")
g.addEdge(vaeNode || modelNode, "vae", l2iNode, "vae")
g.addEdge(noiseNode, "noise", tiledMultidiffusionNode, "noise")
g.addEdge(i2lNode, "latents", tiledMultidiffusionNode, "latents")
g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning');
g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning');
g.addEdge(modelNode, "unet", tiledMultidiffusionNode, "unet")
g.addEdge(tiledMultidiffusionNode, "latents", l2iNode, "latents")
const controlnetTileModel = { // TODO: figure out how to handle this, can't assume name is `tile` or that they have it installed
key: "",
hash: "",
type: "controlnet" as any, type: "controlnet" as any,
name: "tile", name: "tile",
base: model.base base: model.base
} }
const controlnet: Invocation<"controlnet"> = { const controlnetNode1 = g.addNode({
id: "controlnet", id: 'controlnet_1',
type: "controlnet", type: "controlnet",
control_model: controlnetModel, control_model: controlnetTileModel,
control_mode: "balanced", control_mode: "balanced",
resize_mode: "just_resize", resize_mode: "just_resize",
control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35 control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35,
} begin_step_percent: 0,
end_step_percent: ((structure + 10) * 0.025) + 0.3
})
g.addEdge(resizeNode, "image", controlnetNode1, "image")
const noiseNode: Invocation<'noise'> = { const controlnetNode2 = g.addNode({
id: "noise", id: "controlnet_2",
type: "noise", type: "controlnet",
seed, control_model: controlnetTileModel,
// width: resized output width control_mode: "balanced",
// height: resized output height resize_mode: "just_resize",
} control_weight: (((structure + 10) * 0.025) + 0.3) * 0.013,
begin_step_percent: ((structure + 10) * 0.025) + 0.3,
end_step_percent: 0.8
})
const posPrompt: Invocation<"compel"> = { g.addEdge(resizeNode, "image", controlnetNode2, "image")
type: 'compel',
id: POSITIVE_CONDITIONING, const collectNode = g.addNode({
prompt: positivePrompt, id: CONTROL_NET_COLLECT,
} type: "collect",
})
g.addEdge(controlnetNode1, "control", collectNode, "item")
g.addEdge(controlnetNode2, "control", collectNode, "item")
g.addEdge(collectNode, "collection", tiledMultidiffusionNode, "control")
const negPrompt: Invocation<"compel"> = {
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
}
return g.getGraph(); return g.getGraph();

View File

@ -53,6 +53,8 @@ export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted'; export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect'; export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect'; export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
export const UNSHARP_MASK = 'unsharp_mask'
export const TILED_MULTI_DIFFUSION_DENOISE_LATENTS = "tiled_multi_diffusion_denoise_latents"
// friendly graph ids // friendly graph ids
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph'; export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';

View File

@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types';
export const addLoRAs = ( export const addLoRAs = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>,
modelLoader: Invocation<'main_model_loader'>, modelLoader: Invocation<'main_model_loader'>,
seamless: Invocation<'seamless'> | null, seamless: Invocation<'seamless'> | null,
clipSkip: Invocation<'clip_skip'>, clipSkip: Invocation<'clip_skip'>,

View File

@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types';
export const addSDXLLoRas = ( export const addSDXLLoRas = (
state: RootState, state: RootState,
g: Graph, g: Graph,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>,
modelLoader: Invocation<'sdxl_model_loader'>, modelLoader: Invocation<'sdxl_model_loader'>,
seamless: Invocation<'seamless'> | null, seamless: Invocation<'seamless'> | null,
posCond: Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'sdxl_compel_prompt'>,