mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): clean up base model handling in graph builder
This commit is contained in:
parent
9fb03d43ff
commit
b239891986
@ -39,6 +39,7 @@ import { assert } from 'tsafe';
|
|||||||
export const addGenerationTabControlLayers = async (
|
export const addGenerationTabControlLayers = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
|
base: BaseModelType,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
||||||
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||||
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||||
@ -51,12 +52,9 @@ export const addGenerationTabControlLayers = async (
|
|||||||
| Invocation<'main_model_loader'>
|
| Invocation<'main_model_loader'>
|
||||||
| Invocation<'sdxl_model_loader'>
|
| Invocation<'sdxl_model_loader'>
|
||||||
): Promise<Layer[]> => {
|
): Promise<Layer[]> => {
|
||||||
const mainModel = state.generation.model;
|
const isSDXL = base === 'sdxl';
|
||||||
assert(mainModel, 'Missing main model when building graph');
|
|
||||||
const isSDXL = mainModel.base === 'sdxl';
|
|
||||||
|
|
||||||
// Filter out layers with incompatible base model, missing control image
|
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, base));
|
||||||
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
|
|
||||||
|
|
||||||
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
|
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
|
||||||
for (const ca of validControlAdapters) {
|
for (const ca of validControlAdapters) {
|
||||||
@ -71,7 +69,7 @@ export const addGenerationTabControlLayers = async (
|
|||||||
const initialImageLayers = validLayers.filter(isInitialImageLayer);
|
const initialImageLayers = validLayers.filter(isInitialImageLayer);
|
||||||
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
|
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
|
||||||
if (initialImageLayers[0]) {
|
if (initialImageLayers[0]) {
|
||||||
addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]);
|
addInitialImageLayerToGraph(state, g, base, denoise, noise, vaeSource, initialImageLayers[0]);
|
||||||
}
|
}
|
||||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||||
// the existing conditioning nodes.
|
// the existing conditioning nodes.
|
||||||
@ -214,9 +212,7 @@ export const addGenerationTabControlLayers = async (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
|
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
|
||||||
isValidIPAdapter(ipa, mainModel.base)
|
|
||||||
);
|
|
||||||
|
|
||||||
for (const ipAdapterConfig of validRegionalIPAdapters) {
|
for (const ipAdapterConfig of validRegionalIPAdapters) {
|
||||||
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||||
@ -428,6 +424,7 @@ const addGlobalIPAdapterToGraph = (
|
|||||||
const addInitialImageLayerToGraph = (
|
const addInitialImageLayerToGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
g: Graph,
|
g: Graph,
|
||||||
|
base: BaseModelType,
|
||||||
denoise: Invocation<'denoise_latents'>,
|
denoise: Invocation<'denoise_latents'>,
|
||||||
noise: Invocation<'noise'>,
|
noise: Invocation<'noise'>,
|
||||||
vaeSource:
|
vaeSource:
|
||||||
@ -437,13 +434,13 @@ const addInitialImageLayerToGraph = (
|
|||||||
| Invocation<'sdxl_model_loader'>,
|
| Invocation<'sdxl_model_loader'>,
|
||||||
layer: InitialImageLayer
|
layer: InitialImageLayer
|
||||||
) => {
|
) => {
|
||||||
const { vaePrecision, model } = state.generation;
|
const { vaePrecision } = state.generation;
|
||||||
const { refinerModel, refinerStart } = state.sdxl;
|
const { refinerModel, refinerStart } = state.sdxl;
|
||||||
const { width, height } = state.controlLayers.present.size;
|
const { width, height } = state.controlLayers.present.size;
|
||||||
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
||||||
assert(layer.image, 'Initial image layer has no image');
|
assert(layer.image, 'Initial image layer has no image');
|
||||||
|
|
||||||
const isSDXL = model?.base === 'sdxl';
|
const isSDXL = base === 'sdxl';
|
||||||
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
||||||
|
|
||||||
const { denoisingStrength } = layer;
|
const { denoisingStrength } = layer;
|
||||||
|
@ -126,6 +126,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2');
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
@ -155,6 +156,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
|||||||
const addedLayers = await addGenerationTabControlLayers(
|
const addedLayers = await addGenerationTabControlLayers(
|
||||||
state,
|
state,
|
||||||
g,
|
g,
|
||||||
|
modelConfig.base,
|
||||||
denoise,
|
denoise,
|
||||||
posCond,
|
posCond,
|
||||||
negCond,
|
negCond,
|
||||||
|
@ -117,6 +117,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||||
|
|
||||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
assert(modelConfig.base === 'sdxl');
|
||||||
|
|
||||||
g.upsertMetadata({
|
g.upsertMetadata({
|
||||||
generation_mode: 'txt2img',
|
generation_mode: 'txt2img',
|
||||||
@ -152,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
|||||||
await addGenerationTabControlLayers(
|
await addGenerationTabControlLayers(
|
||||||
state,
|
state,
|
||||||
g,
|
g,
|
||||||
|
modelConfig.base,
|
||||||
denoise,
|
denoise,
|
||||||
posCond,
|
posCond,
|
||||||
negCond,
|
negCond,
|
||||||
|
Loading…
Reference in New Issue
Block a user