fix(ui): fix metadata for graphs to use new enriched format

This commit is contained in:
psychedelicious 2024-03-06 19:38:38 +11:00
parent d4a2ea68fc
commit c57f6ee885
25 changed files with 317 additions and 176 deletions

View File

@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
).unwrap();
}
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
log.debug({ graph: parseify(graph) }, `Canvas graph built`);

View File

@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
if (model && model.base === 'sdxl') {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearSDXLTextToImageGraph(state);
graph = await buildLinearSDXLTextToImageGraph(state);
} else {
graph = buildLinearSDXLImageToImageGraph(state);
graph = await buildLinearSDXLImageToImageGraph(state);
}
} else {
if (action.payload.tabName === 'txt2img') {
graph = buildLinearTextToImageGraph(state);
graph = await buildLinearTextToImageGraph(state);
} else {
graph = buildLinearImageToImageGraph(state);
graph = await buildLinearImageToImageGraph(state);
}
}

View File

@ -55,8 +55,22 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'tokenizer',
'tokenizer_2',
'vae',
'vae_decoder',
'vae_encoder',
'scheduler',
'safety_checker',
]);
const zModelIdentifier = z.object({
key: z.string().min(1),
submodel_type: zSubModelType.nullish(),
});
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
zModelIdentifier.safeParse(field).success;

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type {
CollectInvocation,
ControlField,
ControlNetInvocation,
CoreMetadataInvocation,
NonNullableGraph,
} from 'services/api/types';
import { isControlNetModelConfig } from 'services/api/types';
import { CONTROL_NET_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addControlNetToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
},
});
validControlNets.forEach((controlNet) => {
validControlNets.forEach(async (controlNet) => {
if (!controlNet.model) {
return;
}
@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
controlNetMetadata.push({
control_model: getModelMetadataField(modelConfig),
control_weight: weight,
control_mode: controlMode,
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
image: controlNetNode.image,
});
graph.edges.push({
source: { node_id: controlNetNode.id, field: 'control' },

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type {
CollectInvocation,
CoreMetadataInvocation,
IPAdapterInvocation,
IPAdapterMetadataField,
NonNullableGraph,
} from 'services/api/types';
import { isIPAdapterModelConfig } from 'services/api/types';
import { IP_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addIPAdapterToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
validIPAdapters.forEach((ipAdapter) => {
validIPAdapters.forEach(async (ipAdapter) => {
if (!ipAdapter.model) {
return;
}
@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
ipAdapterMetdata.push({
weight: weight,
ip_adapter_model: getModelMetadataField(modelConfig),
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
image: ipAdapterNode.image,
});
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

View File

@ -1,16 +1,22 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type LoraLoaderInvocation,
type NonNullableGraph,
} from 'services/api/types';
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addLoRAsToGraph = (
export const addLoRAsToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): void => {
): Promise<void> => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -39,7 +45,7 @@ export const addLoRAsToGraph = (
let currentLoraIndex = 0;
const loraMetadata: CoreMetadataInvocation['loras'] = [];
enabledLoRAs.forEach((lora) => {
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
@ -52,8 +58,10 @@ export const addLoRAsToGraph = (
weight,
};
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({
model: { key },
model: getModelMetadataField(modelConfig),
weight,
});

View File

@ -1,6 +1,12 @@
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { filter, size } from 'lodash-es';
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
import {
type CoreMetadataInvocation,
isLoRAModelConfig,
type NonNullableGraph,
type SDXLLoraLoaderInvocation,
} from 'services/api/types';
import {
LORA_LOADER,
@ -10,14 +16,14 @@ import {
SDXL_REFINER_INPAINT_CREATE_MASK,
SEAMLESS,
} from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addSDXLLoRAsToGraph = (
export const addSDXLLoRAsToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId: string = SDXL_MODEL_LOADER
): void => {
): Promise<void> => {
/**
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
@ -55,7 +61,7 @@ export const addSDXLLoRAsToGraph = (
let lastLoraNodeId = '';
let currentLoraIndex = 0;
enabledLoRAs.forEach((lora) => {
enabledLoRAs.forEach(async (lora) => {
const { weight } = lora;
const { key } = lora.model;
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = (
weight,
};
loraMetadata.push({ model: { key }, weight });
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
// add to graph
graph.nodes[currentLoraNodeId] = loraLoaderNode;

View File

@ -1,9 +1,11 @@
import type { RootState } from 'app/store/store';
import type {
CreateDenoiseMaskInvocation,
ImageDTO,
NonNullableGraph,
SeamlessModeInvocation,
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CreateDenoiseMaskInvocation,
type ImageDTO,
isRefinerMainModelModelConfig,
type NonNullableGraph,
type SeamlessModeInvocation,
} from 'services/api/types';
import {
@ -25,16 +27,16 @@ import {
SDXL_REFINER_SEAMLESS,
} from './constants';
import { getSDXLStylePrompts } from './graphBuilderUtils';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addSDXLRefinerToGraph = (
export const addSDXLRefinerToGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string,
modelLoaderNodeId?: string,
canvasInitImage?: ImageDTO,
canvasMaskImage?: ImageDTO
): void => {
): Promise<void> => {
const {
refinerModel,
refinerPositiveAestheticScore,
@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = (
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
upsertMetadata(graph, {
refiner_model: refinerModel,
refiner_model: getModelMetadataField(modelConfig),
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
refiner_cfg_scale: refinerCFGScale,

View File

@ -1,18 +1,22 @@
import type { RootState } from 'app/store/store';
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
import { omit } from 'lodash-es';
import type {
CollectInvocation,
CoreMetadataInvocation,
NonNullableGraph,
T2IAdapterField,
T2IAdapterInvocation,
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type CollectInvocation,
type CoreMetadataInvocation,
isT2IAdapterModelConfig,
type NonNullableGraph,
type T2IAdapterInvocation,
} from 'services/api/types';
import { T2I_ADAPTER_COLLECT } from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
export const addT2IAdaptersToLinearGraph = async (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): Promise<void> => {
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
(ca) => ca.model?.base === state.generation.model?.base
);
@ -35,7 +39,7 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
validT2IAdapters.forEach((t2iAdapter) => {
validT2IAdapters.forEach(async (t2iAdapter) => {
if (!t2iAdapter.model) {
return;
}
@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
return;
}
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
t2iAdapterMetdata.push({
begin_step_percent: beginStepPct,
end_step_percent: endStepPct,
resize_mode: resizeMode,
t2i_adapter_model: getModelMetadataField(modelConfig),
weight: weight,
image: t2iAdapterNode.image,
});
graph.edges.push({
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },

View File

@ -1,5 +1,7 @@
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { isVAEModelConfig } from 'services/api/types';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
@ -23,13 +25,13 @@ import {
TEXT_TO_IMAGE_GRAPH,
VAE_LOADER,
} from './constants';
import { upsertMetadata } from './metadata';
import { getModelMetadataField, upsertMetadata } from './metadata';
export const addVAEToGraph = (
export const addVAEToGraph = async (
state: RootState,
graph: NonNullableGraph,
modelLoaderNodeId: string = MAIN_MODEL_LOADER
): void => {
): Promise<void> => {
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
const { boundingBoxScaleMethod } = state.canvas;
const { refinerModel } = state.sdxl;
@ -149,6 +151,8 @@ export const addVAEToGraph = (
}
if (vae) {
upsertMetadata(graph, { vae });
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
upsertMetadata(graph, { vae: vaeMetadata });
}
};

View File

@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
export const buildCanvasGraph = (
export const buildCanvasGraph = async (
state: RootState,
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
canvasInitImage: ImageDTO | undefined,
canvasMaskImage: ImageDTO | undefined
) => {
): Promise<NonNullableGraph> => {
let graph: NonNullableGraph;
if (generationMode === 'txt2img') {
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLTextToImageGraph(state);
graph = await buildCanvasSDXLTextToImageGraph(state);
} else {
graph = buildCanvasTextToImageGraph(state);
graph = await buildCanvasTextToImageGraph(state);
}
} else if (generationMode === 'img2img') {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
} else {
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
}
} else if (generationMode === 'inpaint') {
if (!canvasInitImage || !canvasMaskImage) {
throw new Error('Missing canvas init and mask images');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
}
} else {
if (!canvasInitImage) {
throw new Error('Missing canvas init image');
}
if (state.generation.model && state.generation.model.base === 'sdxl') {
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
} else {
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
}
}

View File

@ -1,7 +1,13 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -25,12 +31,15 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
export const buildCanvasImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
}
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasInpaintGraph = (
export const buildCanvasInpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = (
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!

View File

@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasOutpaintGraph = (
export const buildCanvasOutpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = (
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,12 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageDTO,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -26,12 +32,15 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Image to Image graph.
*/
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
export const buildCanvasSDXLImageToImageGraph = async (
state: RootState,
initialImage: ImageDTO
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Inpaint graph.
*/
export const buildCanvasSDXLInpaintGraph = (
export const buildCanvasSDXLInpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = (
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
await addSDXLRefinerToGraph(
state,
graph,
SDXL_DENOISE_LATENTS,
modelLoaderNodeId,
canvasInitImage,
canvasMaskImage
);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
/**
* Builds the Canvas tab's Outpaint graph.
*/
export const buildCanvasSDXLOutpaintGraph = (
export const buildCanvasSDXLOutpaintGraph = async (
state: RootState,
canvasInitImage: ImageDTO,
canvasMaskImage?: ImageDTO
): NonNullableGraph => {
): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = (
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// Add VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +25,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
negative_prompt: negativePrompt,
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,8 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -23,12 +24,12 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Canvas tab's Text to Image graph.
*/
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,13 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -24,12 +30,12 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,12 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import {
type ImageResizeInvocation,
type ImageToLatentsInvocation,
isNonRefinerMainModelConfig,
type NonNullableGraph,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -25,12 +31,12 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
/**
* Builds the Image to Image tab graph.
*/
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
});
}
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// Add LoRA Support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,6 +1,7 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { NonNullableGraph } from 'services/api/types';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
@ -23,9 +24,9 @@ import {
SEAMLESS,
} from './constants';
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
// Add Refiner if enabled
if (refinerModel) {
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
if (seamlessXAxis || seamlessYAxis) {
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
}
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {

View File

@ -1,7 +1,8 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
import type { NonNullableGraph } from 'services/api/types';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addHrfToGraph } from './addHrfToGraph';
@ -23,9 +24,9 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addCoreMetadataNode } from './metadata';
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
const log = logger('nodes');
const {
positivePrompt,
@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
],
};
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
addCoreMetadataNode(
graph,
{
@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
width,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model,
model: getModelMetadataField(modelConfig),
seed,
steps,
rand_device: use_cpu ? 'cpu' : 'cuda',
@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
}
// optionally add custom VAE
addVAEToGraph(state, graph, modelLoaderNodeId);
await addVAEToGraph(state, graph, modelLoaderNodeId);
// add LoRA support
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// High resolution fix.
if (state.hrf.hrfEnabled) {

View File

@ -1,5 +1,5 @@
import type { JSONObject } from 'common/types';
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
import { METADATA } from './constants';
@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
},
});
};
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
key,
hash,
name,
base,
type,
});

View File

@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
export type ModelType = S['ModelType'];
export type SubModelType = S['SubModelType'];
export type BaseModelType = S['BaseModelType'];
export type ControlField = S['ControlField'];
// Model Configs
@ -129,8 +128,9 @@ export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
export type T2IAdapterField = S['T2IAdapterField'];
// Metadata fields
export type ModelMetadataField = S['ModelMetadataField'];
// ControlNet Nodes
export type ControlNetInvocation = S['ControlNetInvocation'];