tidy(ui): clean up base model handling in graph builder

This commit is contained in:
psychedelicious 2024-05-14 14:21:52 +10:00
parent 9fb03d43ff
commit b239891986
3 changed files with 12 additions and 11 deletions

View File

@ -39,6 +39,7 @@ import { assert } from 'tsafe';
export const addGenerationTabControlLayers = async (
state: RootState,
g: Graph,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>,
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
@ -51,12 +52,9 @@ export const addGenerationTabControlLayers = async (
| Invocation<'main_model_loader'>
| Invocation<'sdxl_model_loader'>
): Promise<Layer[]> => {
const mainModel = state.generation.model;
assert(mainModel, 'Missing main model when building graph');
const isSDXL = mainModel.base === 'sdxl';
const isSDXL = base === 'sdxl';
// Filter out layers with incompatible base model, missing control image
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, base));
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
for (const ca of validControlAdapters) {
@ -71,7 +69,7 @@ export const addGenerationTabControlLayers = async (
const initialImageLayers = validLayers.filter(isInitialImageLayer);
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
if (initialImageLayers[0]) {
addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]);
addInitialImageLayerToGraph(state, g, base, denoise, noise, vaeSource, initialImageLayers[0]);
}
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
// the existing conditioning nodes.
@ -214,9 +212,7 @@ export const addGenerationTabControlLayers = async (
}
}
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
isValidIPAdapter(ipa, mainModel.base)
);
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
for (const ipAdapterConfig of validRegionalIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
@ -428,6 +424,7 @@ const addGlobalIPAdapterToGraph = (
const addInitialImageLayerToGraph = (
state: RootState,
g: Graph,
base: BaseModelType,
denoise: Invocation<'denoise_latents'>,
noise: Invocation<'noise'>,
vaeSource:
@ -437,13 +434,13 @@ const addInitialImageLayerToGraph = (
| Invocation<'sdxl_model_loader'>,
layer: InitialImageLayer
) => {
const { vaePrecision, model } = state.generation;
const { vaePrecision } = state.generation;
const { refinerModel, refinerStart } = state.sdxl;
const { width, height } = state.controlLayers.present.size;
assert(layer.isEnabled, 'Initial image layer is not enabled');
assert(layer.image, 'Initial image layer has no image');
const isSDXL = model?.base === 'sdxl';
const isSDXL = base === 'sdxl';
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
const { denoisingStrength } = layer;

View File

@ -126,6 +126,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2');
g.upsertMetadata({
generation_mode: 'txt2img',
@ -155,6 +156,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
const addedLayers = await addGenerationTabControlLayers(
state,
g,
modelConfig.base,
denoise,
posCond,
negCond,

View File

@ -117,6 +117,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sdxl');
g.upsertMetadata({
generation_mode: 'txt2img',
@ -152,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
await addGenerationTabControlLayers(
state,
g,
modelConfig.base,
denoise,
posCond,
negCond,