mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): clean up base model handling in graph builder
This commit is contained in:
parent
9fb03d43ff
commit
b239891986
@ -39,6 +39,7 @@ import { assert } from 'tsafe';
|
||||
export const addGenerationTabControlLayers = async (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
base: BaseModelType,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
|
||||
@ -51,12 +52,9 @@ export const addGenerationTabControlLayers = async (
|
||||
| Invocation<'main_model_loader'>
|
||||
| Invocation<'sdxl_model_loader'>
|
||||
): Promise<Layer[]> => {
|
||||
const mainModel = state.generation.model;
|
||||
assert(mainModel, 'Missing main model when building graph');
|
||||
const isSDXL = mainModel.base === 'sdxl';
|
||||
const isSDXL = base === 'sdxl';
|
||||
|
||||
// Filter out layers with incompatible base model, missing control image
|
||||
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, mainModel.base));
|
||||
const validLayers = state.controlLayers.present.layers.filter((l) => isValidLayer(l, base));
|
||||
|
||||
const validControlAdapters = validLayers.filter(isControlAdapterLayer).map((l) => l.controlAdapter);
|
||||
for (const ca of validControlAdapters) {
|
||||
@ -71,7 +69,7 @@ export const addGenerationTabControlLayers = async (
|
||||
const initialImageLayers = validLayers.filter(isInitialImageLayer);
|
||||
assert(initialImageLayers.length <= 1, 'Only one initial image layer allowed');
|
||||
if (initialImageLayers[0]) {
|
||||
addInitialImageLayerToGraph(state, g, denoise, noise, vaeSource, initialImageLayers[0]);
|
||||
addInitialImageLayerToGraph(state, g, base, denoise, noise, vaeSource, initialImageLayers[0]);
|
||||
}
|
||||
// TODO: We should probably just use conditioning collectors by default, and skip all this fanagling with re-routing
|
||||
// the existing conditioning nodes.
|
||||
@ -214,9 +212,7 @@ export const addGenerationTabControlLayers = async (
|
||||
}
|
||||
}
|
||||
|
||||
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) =>
|
||||
isValidIPAdapter(ipa, mainModel.base)
|
||||
);
|
||||
const validRegionalIPAdapters: IPAdapterConfigV2[] = layer.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
|
||||
|
||||
for (const ipAdapterConfig of validRegionalIPAdapters) {
|
||||
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
|
||||
@ -428,6 +424,7 @@ const addGlobalIPAdapterToGraph = (
|
||||
const addInitialImageLayerToGraph = (
|
||||
state: RootState,
|
||||
g: Graph,
|
||||
base: BaseModelType,
|
||||
denoise: Invocation<'denoise_latents'>,
|
||||
noise: Invocation<'noise'>,
|
||||
vaeSource:
|
||||
@ -437,13 +434,13 @@ const addInitialImageLayerToGraph = (
|
||||
| Invocation<'sdxl_model_loader'>,
|
||||
layer: InitialImageLayer
|
||||
) => {
|
||||
const { vaePrecision, model } = state.generation;
|
||||
const { vaePrecision } = state.generation;
|
||||
const { refinerModel, refinerStart } = state.sdxl;
|
||||
const { width, height } = state.controlLayers.present.size;
|
||||
assert(layer.isEnabled, 'Initial image layer is not enabled');
|
||||
assert(layer.image, 'Initial image layer has no image');
|
||||
|
||||
const isSDXL = model?.base === 'sdxl';
|
||||
const isSDXL = base === 'sdxl';
|
||||
const useRefinerStartEnd = isSDXL && Boolean(refinerModel);
|
||||
|
||||
const { denoisingStrength } = layer;
|
||||
|
@ -126,6 +126,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
assert(modelConfig.base === 'sd-1' || modelConfig.base === 'sd-2');
|
||||
|
||||
g.upsertMetadata({
|
||||
generation_mode: 'txt2img',
|
||||
@ -155,6 +156,7 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<GraphTy
|
||||
const addedLayers = await addGenerationTabControlLayers(
|
||||
state,
|
||||
g,
|
||||
modelConfig.base,
|
||||
denoise,
|
||||
posCond,
|
||||
negCond,
|
||||
|
@ -117,6 +117,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
g.addEdge(denoise, 'latents', l2i, 'latents');
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
assert(modelConfig.base === 'sdxl');
|
||||
|
||||
g.upsertMetadata({
|
||||
generation_mode: 'txt2img',
|
||||
@ -152,6 +153,7 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
|
||||
await addGenerationTabControlLayers(
|
||||
state,
|
||||
g,
|
||||
modelConfig.base,
|
||||
denoise,
|
||||
posCond,
|
||||
negCond,
|
||||
|
Loading…
Reference in New Issue
Block a user