mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cleanup, add loras
This commit is contained in:
parent
ea449f5a0a
commit
7668dc68a0
@ -1,11 +1,12 @@
|
|||||||
import { Graph, GraphType } from 'features/nodes/util/graph/generation/Graph';
|
import { Graph, GraphType } from 'features/nodes/util/graph/generation/Graph';
|
||||||
import { RootState } from '../../../../app/store/store';
|
import { RootState } from '../../../../app/store/store';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
import { ControlNetModelConfig, Invocation, NonNullableGraph } from '../../../../services/api/types';
|
import { CLIP_SKIP, CONTROL_NET_COLLECT, ESRGAN, IMAGE_TO_LATENTS, LATENTS_TO_IMAGE, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, NOISE, POSITIVE_CONDITIONING, RESIZE, SDXL_MODEL_LOADER, TILED_MULTI_DIFFUSION_DENOISE_LATENTS, UNSHARP_MASK, VAE_LOADER } from './constants';
|
||||||
import { ESRGAN, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
|
||||||
import { isParamESRGANModelName } from '../../../parameters/store/postprocessingSlice';
|
import { isParamESRGANModelName } from '../../../parameters/store/postprocessingSlice';
|
||||||
import { ControlNetConfig } from '../../../controlAdapters/store/types';
|
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { MODEL_TYPES } from '../../types/constants';
|
import { addLoRAs } from './generation/addLoRAs';
|
||||||
|
import { addSDXLLoRas } from './generation/addSDXLLoRAs';
|
||||||
|
import { modelsApi } from '../../../../services/api/endpoints/models';
|
||||||
|
|
||||||
|
|
||||||
export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => {
|
export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promise<GraphType> => {
|
||||||
@ -24,63 +25,75 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
|
|||||||
assert(model, 'No model found in state');
|
assert(model, 'No model found in state');
|
||||||
assert(upscaleModel, 'No upscale model found in state');
|
assert(upscaleModel, 'No upscale model found in state');
|
||||||
assert(upscaleInitialImage, 'No initial image found in state');
|
assert(upscaleInitialImage, 'No initial image found in state');
|
||||||
|
assert(isParamESRGANModelName(upscaleModel.name), "")
|
||||||
if (!isParamESRGANModelName(upscaleModel.name)) {
|
|
||||||
throw new Error()
|
|
||||||
}
|
|
||||||
|
|
||||||
const g = new Graph()
|
const g = new Graph()
|
||||||
|
|
||||||
const unsharp_mask_1 = g.addNode({
|
const unsharpMaskNode1 = g.addNode({
|
||||||
id: 'unsharp_mask_1',
|
id: `${UNSHARP_MASK}_1`,
|
||||||
type: 'unsharp_mask',
|
type: 'unsharp_mask',
|
||||||
image: upscaleInitialImage,
|
image: upscaleInitialImage,
|
||||||
radius: 2,
|
radius: 2,
|
||||||
strength: ((sharpness + 10) * 3.75) + 25
|
strength: ((sharpness + 10) * 3.75) + 25
|
||||||
})
|
})
|
||||||
|
|
||||||
const esrgan = g.addNode({
|
const upscaleNode = g.addNode({
|
||||||
id: ESRGAN,
|
id: ESRGAN,
|
||||||
type: 'esrgan',
|
type: 'esrgan',
|
||||||
model_name: upscaleModel.name,
|
model_name: upscaleModel.name,
|
||||||
tile_size: 500
|
tile_size: 500
|
||||||
})
|
})
|
||||||
|
|
||||||
g.addEdge(unsharp_mask_1, 'image', esrgan, 'image')
|
g.addEdge(unsharpMaskNode1, 'image', upscaleNode, 'image')
|
||||||
|
|
||||||
const unsharp_mask_2 = g.addNode({
|
const unsharpMaskNode2 = g.addNode({
|
||||||
id: 'unsharp_mask_2',
|
id: `${UNSHARP_MASK}_2`,
|
||||||
type: 'unsharp_mask',
|
type: 'unsharp_mask',
|
||||||
radius: 2,
|
radius: 2,
|
||||||
strength: 50
|
strength: 50
|
||||||
})
|
})
|
||||||
|
|
||||||
g.addEdge(esrgan, 'image', unsharp_mask_2, 'image',)
|
g.addEdge(upscaleNode, 'image', unsharpMaskNode2, 'image',)
|
||||||
|
|
||||||
const SCALE = 2
|
const SCALE = 4
|
||||||
|
|
||||||
const resizeNode = g.addNode({
|
const resizeNode = g.addNode({
|
||||||
id: 'img_resize',
|
id: RESIZE,
|
||||||
type: 'img_resize',
|
type: 'img_resize',
|
||||||
width: upscaleInitialImage.width * SCALE, // TODO: handle floats
|
width: upscaleInitialImage.width * SCALE, // TODO: handle floats
|
||||||
height: upscaleInitialImage.height * SCALE, // TODO: handle floats
|
height: upscaleInitialImage.height * SCALE, // TODO: handle floats
|
||||||
resample_mode: "lanczos"
|
resample_mode: "lanczos",
|
||||||
|
is_intermediate: false
|
||||||
})
|
})
|
||||||
|
|
||||||
g.addEdge(unsharp_mask_2, 'image', resizeNode, "image")
|
g.addEdge(unsharpMaskNode2, 'image', resizeNode, "image")
|
||||||
|
|
||||||
|
const noiseNode = g.addNode({
|
||||||
|
id: NOISE,
|
||||||
|
type: "noise",
|
||||||
|
seed,
|
||||||
|
})
|
||||||
|
|
||||||
|
g.addEdge(resizeNode, 'width', noiseNode, "width")
|
||||||
|
g.addEdge(resizeNode, 'height', noiseNode, "height")
|
||||||
|
|
||||||
const sharpnessNode: Invocation<'unsharp_mask'> = { //before and after esrgan
|
const i2lNode = g.addNode({
|
||||||
id: 'unsharp_mask',
|
id: IMAGE_TO_LATENTS,
|
||||||
type: 'unsharp_mask',
|
type: "i2l",
|
||||||
image: upscaleInitialImage,
|
is_intermediate: false,
|
||||||
radius: 2,
|
fp32: vaePrecision === "fp32"
|
||||||
strength: ((sharpness + 10) * 3.75) + 25
|
})
|
||||||
};
|
|
||||||
|
|
||||||
const creativityNode: Invocation<'tiled_multi_diffusion_denoise_latents'> = { //before and after esrgan
|
g.addEdge(resizeNode, 'image', i2lNode, "image")
|
||||||
id: 'tiled_multi_diffusion_denoise_latents',
|
|
||||||
|
const l2iNode = g.addNode({
|
||||||
|
type: "l2i",
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
fp32: vaePrecision === "fp32"
|
||||||
|
})
|
||||||
|
|
||||||
|
const tiledMultidiffusionNode = g.addNode({
|
||||||
|
id: TILED_MULTI_DIFFUSION_DENOISE_LATENTS,
|
||||||
type: 'tiled_multi_diffusion_denoise_latents',
|
type: 'tiled_multi_diffusion_denoise_latents',
|
||||||
tile_height: 1024,
|
tile_height: 1024,
|
||||||
tile_width: 1024,
|
tile_width: 1024,
|
||||||
@ -90,45 +103,124 @@ export const buildMultidiffusionUpscsaleGraph = async (state: RootState): Promis
|
|||||||
scheduler,
|
scheduler,
|
||||||
denoising_start: (((creativity * -1) + 10) * 4.99) / 100,
|
denoising_start: (((creativity * -1) + 10) * 4.99) / 100,
|
||||||
denoising_end: 1
|
denoising_end: 1
|
||||||
};
|
});
|
||||||
|
|
||||||
const controlnetModel = {
|
const clipSkipNode = g.addNode({
|
||||||
key: "placeholder",
|
type: 'clip_skip',
|
||||||
hash: "placeholder",
|
id: CLIP_SKIP,
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
let posCondNode, negCondNode, modelNode;
|
||||||
|
|
||||||
|
if (model.base === "sdxl") {
|
||||||
|
const { positiveStylePrompt, negativeStylePrompt } = getSDXLStylePrompts(state);
|
||||||
|
|
||||||
|
posCondNode = g.addNode({
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
style: positiveStylePrompt
|
||||||
|
});
|
||||||
|
negCondNode = g.addNode({
|
||||||
|
type: 'sdxl_compel_prompt',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
style: negativeStylePrompt
|
||||||
|
});
|
||||||
|
modelNode = g.addNode({
|
||||||
|
type: 'sdxl_model_loader',
|
||||||
|
id: SDXL_MODEL_LOADER,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
addSDXLLoRas(state, g, tiledMultidiffusionNode, modelNode, null, posCondNode, negCondNode);
|
||||||
|
} else {
|
||||||
|
posCondNode = g.addNode({
|
||||||
|
type: 'compel',
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
prompt: positivePrompt,
|
||||||
|
});
|
||||||
|
negCondNode = g.addNode({
|
||||||
|
type: 'compel',
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
prompt: negativePrompt,
|
||||||
|
});
|
||||||
|
modelNode = g.addNode({
|
||||||
|
type: 'main_model_loader',
|
||||||
|
id: MAIN_MODEL_LOADER,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
addLoRAs(state, g, tiledMultidiffusionNode, modelNode, null, clipSkipNode, posCondNode, negCondNode);
|
||||||
|
}
|
||||||
|
|
||||||
|
g.addEdge(modelNode, 'clip', clipSkipNode, 'clip');
|
||||||
|
g.addEdge(clipSkipNode, 'clip', posCondNode, 'clip');
|
||||||
|
g.addEdge(clipSkipNode, 'clip', negCondNode, 'clip');
|
||||||
|
|
||||||
|
let vaeNode;
|
||||||
|
if (vae) {
|
||||||
|
vaeNode = g.addNode({
|
||||||
|
id: VAE_LOADER,
|
||||||
|
type: "vae_loader",
|
||||||
|
vae_model: vae
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
g.addEdge(vaeNode || modelNode, "vae", i2lNode, "vae")
|
||||||
|
g.addEdge(vaeNode || modelNode, "vae", l2iNode, "vae")
|
||||||
|
|
||||||
|
|
||||||
|
g.addEdge(noiseNode, "noise", tiledMultidiffusionNode, "noise")
|
||||||
|
g.addEdge(i2lNode, "latents", tiledMultidiffusionNode, "latents")
|
||||||
|
g.addEdge(posCondNode, 'conditioning', tiledMultidiffusionNode, 'positive_conditioning');
|
||||||
|
g.addEdge(negCondNode, 'conditioning', tiledMultidiffusionNode, 'negative_conditioning');
|
||||||
|
g.addEdge(modelNode, "unet", tiledMultidiffusionNode, "unet")
|
||||||
|
g.addEdge(tiledMultidiffusionNode, "latents", l2iNode, "latents")
|
||||||
|
|
||||||
|
|
||||||
|
const controlnetTileModel = { // TODO: figure out how to handle this, can't assume name is `tile` or that they have it installed
|
||||||
|
key: "",
|
||||||
|
hash: "",
|
||||||
type: "controlnet" as any,
|
type: "controlnet" as any,
|
||||||
name: "tile",
|
name: "tile",
|
||||||
base: model.base
|
base: model.base
|
||||||
}
|
}
|
||||||
|
|
||||||
const controlnet: Invocation<"controlnet"> = {
|
const controlnetNode1 = g.addNode({
|
||||||
id: "controlnet",
|
id: 'controlnet_1',
|
||||||
type: "controlnet",
|
type: "controlnet",
|
||||||
control_model: controlnetModel,
|
control_model: controlnetTileModel,
|
||||||
control_mode: "balanced",
|
control_mode: "balanced",
|
||||||
resize_mode: "just_resize",
|
resize_mode: "just_resize",
|
||||||
control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35
|
control_weight: ((((structure + 10) * 0.025) + 0.3) * 0.013) + 0.35,
|
||||||
}
|
begin_step_percent: 0,
|
||||||
|
end_step_percent: ((structure + 10) * 0.025) + 0.3
|
||||||
|
})
|
||||||
|
|
||||||
|
g.addEdge(resizeNode, "image", controlnetNode1, "image")
|
||||||
|
|
||||||
const noiseNode: Invocation<'noise'> = {
|
const controlnetNode2 = g.addNode({
|
||||||
id: "noise",
|
id: "controlnet_2",
|
||||||
type: "noise",
|
type: "controlnet",
|
||||||
seed,
|
control_model: controlnetTileModel,
|
||||||
// width: resized output width
|
control_mode: "balanced",
|
||||||
// height: resized output height
|
resize_mode: "just_resize",
|
||||||
}
|
control_weight: (((structure + 10) * 0.025) + 0.3) * 0.013,
|
||||||
|
begin_step_percent: ((structure + 10) * 0.025) + 0.3,
|
||||||
|
end_step_percent: 0.8
|
||||||
|
})
|
||||||
|
|
||||||
const posPrompt: Invocation<"compel"> = {
|
g.addEdge(resizeNode, "image", controlnetNode2, "image")
|
||||||
type: 'compel',
|
|
||||||
id: POSITIVE_CONDITIONING,
|
const collectNode = g.addNode({
|
||||||
prompt: positivePrompt,
|
id: CONTROL_NET_COLLECT,
|
||||||
}
|
type: "collect",
|
||||||
|
})
|
||||||
|
g.addEdge(controlnetNode1, "control", collectNode, "item")
|
||||||
|
g.addEdge(controlnetNode2, "control", collectNode, "item")
|
||||||
|
|
||||||
|
g.addEdge(collectNode, "collection", tiledMultidiffusionNode, "control")
|
||||||
|
|
||||||
const negPrompt: Invocation<"compel"> = {
|
|
||||||
type: 'compel',
|
|
||||||
id: NEGATIVE_CONDITIONING,
|
|
||||||
prompt: negativePrompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
return g.getGraph();
|
return g.getGraph();
|
||||||
|
|
||||||
|
@ -53,6 +53,8 @@ export const PROMPT_REGION_NEGATIVE_COND_PREFIX = 'prompt_region_negative_cond';
|
|||||||
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
|
export const PROMPT_REGION_POSITIVE_COND_INVERTED_PREFIX = 'prompt_region_positive_cond_inverted';
|
||||||
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
export const POSITIVE_CONDITIONING_COLLECT = 'positive_conditioning_collect';
|
||||||
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
export const NEGATIVE_CONDITIONING_COLLECT = 'negative_conditioning_collect';
|
||||||
|
export const UNSHARP_MASK = 'unsharp_mask'
|
||||||
|
export const TILED_MULTI_DIFFUSION_DENOISE_LATENTS = "tiled_multi_diffusion_denoise_latents"
|
||||||
|
|
||||||
// friendly graph ids
|
// friendly graph ids
|
||||||
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';
|
export const CONTROL_LAYERS_GRAPH = 'control_layers_graph';
|
||||||
|
@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types';
|
|||||||
export const addLoRAs = (
|
export const addLoRAs = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>,
|
||||||
modelLoader: Invocation<'main_model_loader'>,
|
modelLoader: Invocation<'main_model_loader'>,
|
||||||
seamless: Invocation<'seamless'> | null,
|
seamless: Invocation<'seamless'> | null,
|
||||||
clipSkip: Invocation<'clip_skip'>,
|
clipSkip: Invocation<'clip_skip'>,
|
||||||
|
@ -8,7 +8,7 @@ import type { Invocation, S } from 'services/api/types';
|
|||||||
export const addSDXLLoRas = (
|
export const addSDXLLoRas = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'> | Invocation<'tiled_multi_diffusion_denoise_latents'>,
|
||||||
modelLoader: Invocation<'sdxl_model_loader'>,
|
modelLoader: Invocation<'sdxl_model_loader'>,
|
||||||
seamless: Invocation<'seamless'> | null,
|
seamless: Invocation<'seamless'> | null,
|
||||||
posCond: Invocation<'sdxl_compel_prompt'>,
|
posCond: Invocation<'sdxl_compel_prompt'>,
|
||||||
|
Loading…
Reference in New Issue
Block a user