tidy(ui): clean up base model handling in graph builder

This commit is contained in:
psychedelicious 2024-05-14 14:21:52 +10:00
parent 9fb03d43ff
commit b239891986
3 changed files with 12 additions and 11 deletions

View File

@ -39,6 +39,7 @@ import { assert } from 'tsafe';
export const addGenerationTabControlLayers = async ( export const addGenerationTabControlLayers = async (
state: RootState, state: RootState,
g: Graph, g: Graph,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
@ -51,12 +52,9 @@ export const addGenerationTabControlLayers = async (
| Invocation<'main_model_loader'> | Invocation<'main_model_loader'>
| Invocation<'sdxl_model_loader'> | Invocation<'sdxl_model_loader'>
): Promise<Layer[]> => { ): Promise<Layer[]> => {
const mainModel = state.generation.model; const isSDXL = base === 'sdxl';
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
// Filter out layers with incompatible base model, missing control image const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, base));
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter); const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
for (const ca of validControlAdapters) { for (const ca of validControlAdapters) {
@ -71,7 +69,7 @@ export const addGenerationTabControlLayers = async (
const initialImageLayers = validLayers.filter(isInitialImageLayer); const initialImageLayers = validLayers.filter(isInitialImageLayer);
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed'); assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
if (initialImageLayers[0]) { if (initialImageLayers[0]) {
addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]); addInitialImageLayerToGraph(state, g, base, denoise, noise, vaeSource, initialImageLayers[0]);
} }
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing // TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes. // the existing conditioning nodes.
@ -214,9 +212,7 @@ export const addGenerationTabControlLayers = async (
} }
} }
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) => const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
isValidIPAdapter(ipa, mainModel.base)
);
for (const ipAdapterConfig of validRegionalIPAdapters) { for (const ipAdapterConfig of validRegionalIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise); const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -428,6 +424,7 @@ const addGlobalIPAdapterToGraph = (
const addInitialImageLayerToGraph = ( const addInitialImageLayerToGraph = (
state: RootState, state: RootState,
g: Graph, g: Graph,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>, noise: Invocation<'noise'>,
vaeSource: vaeSource:
@ -437,13 +434,13 @@ const addInitialImageLayerToGraph = (
| Invocation<'sdxl_model_loader'>, | Invocation<'sdxl_model_loader'>,
layer: InitialImageLayer layer: InitialImageLayer
) => { ) => {
const { vaePrecision, model } = state.generation; const { vaePrecision } = state.generation;
const { refinerModel, refinerStart } = state.sdxl; const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size; const { width, height } = state.controlLayers.present.size;
assert(layer.isEnabled, 'Initial image layer is not enabled'); assert(layer.isEnabled, 'Initial image layer is not enabled');
assert(layer.image, 'Initial image layer has no image'); assert(layer.image, 'Initial image layer has no image');
const isSDXL = model?.base === 'sdxl'; const isSDXL = base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel); const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const { denoisingStrength } = layer; const { denoisingStrength } = layer;

View File

@ -126,6 +126,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
g.addEdge(denoise, 'latents', l2i, 'latents'); g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2');
g.upsertMetadata({ g.upsertMetadata({
generation_mode: 'txt2img', generation_mode: 'txt2img',
@ -155,6 +156,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const addedLayers = await addGenerationTabControlLayers( const addedLayers = await addGenerationTabControlLayers(
state, state,
g, g,
modelConfig.base,
denoise, denoise,
posCond, posCond,
negCond, negCond,

View File

@ -117,6 +117,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
g.addEdge(denoise, 'latents', l2i, 'latents'); g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sdxl');
g.upsertMetadata({ g.upsertMetadata({
generation_mode: 'txt2img', generation_mode: 'txt2img',
@ -152,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
await addGenerationTabControlLayers( await addGenerationTabControlLayers(
state, state,
g, g,
modelConfig.base,
denoise, denoise,
posCond, posCond,
negCond, negCond,