mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix metadata for graphs to use new enriched format
This commit is contained in:
parent
d4a2ea68fc
commit
c57f6ee885
@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||
|
||||
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
||||
|
||||
|
@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
|
||||
if (model && model.base === 'sdxl') {
|
||||
if (action.payload.tabName === 'txt2img') {
|
||||
graph = buildLinearSDXLTextToImageGraph(state);
|
||||
graph = await buildLinearSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearSDXLImageToImageGraph(state);
|
||||
graph = await buildLinearSDXLImageToImageGraph(state);
|
||||
}
|
||||
} else {
|
||||
if (action.payload.tabName === 'txt2img') {
|
||||
graph = buildLinearTextToImageGraph(state);
|
||||
graph = await buildLinearTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildLinearImageToImageGraph(state);
|
||||
graph = await buildLinearImageToImageGraph(state);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,8 +55,22 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
'scheduler',
|
||||
'safety_checker',
|
||||
]);
|
||||
|
||||
const zModelIdentifier = z.object({
|
||||
key: z.string().min(1),
|
||||
submodel_type: zSubModelType.nullish(),
|
||||
});
|
||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||
zModelIdentifier.safeParse(field).success;
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
ControlField,
|
||||
ControlNetInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { isControlNetModelConfig } from 'services/api/types';
|
||||
|
||||
import { CONTROL_NET_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addControlNetToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
||||
},
|
||||
});
|
||||
|
||||
validControlNets.forEach((controlNet) => {
|
||||
validControlNets.forEach(async (controlNet) => {
|
||||
if (!controlNet.model) {
|
||||
return;
|
||||
}
|
||||
@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
||||
|
||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||
|
||||
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
|
||||
|
||||
controlNetMetadata.push({
|
||||
control_model: getModelMetadataField(modelConfig),
|
||||
control_weight: weight,
|
||||
control_mode: controlMode,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
resize_mode: resizeMode,
|
||||
image: controlNetNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: controlNetNode.id, field: 'control' },
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
IPAdapterInvocation,
|
||||
IPAdapterMetadataField,
|
||||
NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
import { isIPAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
import { IP_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addIPAdapterToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
||||
|
||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||
|
||||
validIPAdapters.forEach((ipAdapter) => {
|
||||
validIPAdapters.forEach(async (ipAdapter) => {
|
||||
if (!ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||
|
||||
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
|
||||
|
||||
ipAdapterMetdata.push({
|
||||
weight: weight,
|
||||
ip_adapter_model: getModelMetadataField(modelConfig),
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
image: ipAdapterNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||
|
@ -1,16 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type LoraLoaderInvocation,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addLoRAsToGraph = (
|
||||
export const addLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
@ -39,7 +45,7 @@ export const addLoRAsToGraph = (
|
||||
let currentLoraIndex = 0;
|
||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
@ -52,8 +58,10 @@ export const addLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({
|
||||
model: { key },
|
||||
model: getModelMetadataField(modelConfig),
|
||||
weight,
|
||||
});
|
||||
|
||||
|
@ -1,6 +1,12 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { filter, size } from 'lodash-es';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
|
||||
import {
|
||||
type CoreMetadataInvocation,
|
||||
isLoRAModelConfig,
|
||||
type NonNullableGraph,
|
||||
type SDXLLoraLoaderInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import {
|
||||
LORA_LOADER,
|
||||
@ -10,14 +16,14 @@ import {
|
||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLLoRAsToGraph = (
|
||||
export const addSDXLLoRAsToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId: string = SDXL_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
/**
|
||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||
@ -55,7 +61,7 @@ export const addSDXLLoRAsToGraph = (
|
||||
let lastLoraNodeId = '';
|
||||
let currentLoraIndex = 0;
|
||||
|
||||
enabledLoRAs.forEach((lora) => {
|
||||
enabledLoRAs.forEach(async (lora) => {
|
||||
const { weight } = lora;
|
||||
const { key } = lora.model;
|
||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||
@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = (
|
||||
weight,
|
||||
};
|
||||
|
||||
loraMetadata.push({ model: { key }, weight });
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
|
||||
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
|
||||
|
||||
// add to graph
|
||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||
|
@ -1,9 +1,11 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type {
|
||||
CreateDenoiseMaskInvocation,
|
||||
ImageDTO,
|
||||
NonNullableGraph,
|
||||
SeamlessModeInvocation,
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type CreateDenoiseMaskInvocation,
|
||||
type ImageDTO,
|
||||
isRefinerMainModelModelConfig,
|
||||
type NonNullableGraph,
|
||||
type SeamlessModeInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import {
|
||||
@ -25,16 +27,16 @@ import {
|
||||
SDXL_REFINER_SEAMLESS,
|
||||
} from './constants';
|
||||
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addSDXLRefinerToGraph = (
|
||||
export const addSDXLRefinerToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string,
|
||||
modelLoaderNodeId?: string,
|
||||
canvasInitImage?: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
const {
|
||||
refinerModel,
|
||||
refinerPositiveAestheticScore,
|
||||
@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = (
|
||||
const fp32 = vaePrecision === 'fp32';
|
||||
|
||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||
|
||||
upsertMetadata(graph, {
|
||||
refiner_model: refinerModel,
|
||||
refiner_model: getModelMetadataField(modelConfig),
|
||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||
refiner_cfg_scale: refinerCFGScale,
|
||||
|
@ -1,18 +1,22 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { omit } from 'lodash-es';
|
||||
import type {
|
||||
CollectInvocation,
|
||||
CoreMetadataInvocation,
|
||||
NonNullableGraph,
|
||||
T2IAdapterField,
|
||||
T2IAdapterInvocation,
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type CollectInvocation,
|
||||
type CoreMetadataInvocation,
|
||||
isT2IAdapterModelConfig,
|
||||
type NonNullableGraph,
|
||||
type T2IAdapterInvocation,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { T2I_ADAPTER_COLLECT } from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
||||
export const addT2IAdaptersToLinearGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
baseNodeId: string
|
||||
): Promise<void> => {
|
||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||
(ca) => ca.model?.base === state.generation.model?.base
|
||||
);
|
||||
@ -35,7 +39,7 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
||||
|
||||
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||
|
||||
validT2IAdapters.forEach((t2iAdapter) => {
|
||||
validT2IAdapters.forEach(async (t2iAdapter) => {
|
||||
if (!t2iAdapter.model) {
|
||||
return;
|
||||
}
|
||||
@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
||||
return;
|
||||
}
|
||||
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||
|
||||
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
|
||||
|
||||
t2iAdapterMetdata.push({
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
resize_mode: resizeMode,
|
||||
t2i_adapter_model: getModelMetadataField(modelConfig),
|
||||
weight: weight,
|
||||
image: t2iAdapterNode.image,
|
||||
});
|
||||
|
||||
graph.edges.push({
|
||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||
|
@ -1,5 +1,7 @@
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||
import { isVAEModelConfig } from 'services/api/types';
|
||||
|
||||
import {
|
||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||
@ -23,13 +25,13 @@ import {
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
VAE_LOADER,
|
||||
} from './constants';
|
||||
import { upsertMetadata } from './metadata';
|
||||
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||
|
||||
export const addVAEToGraph = (
|
||||
export const addVAEToGraph = async (
|
||||
state: RootState,
|
||||
graph: NonNullableGraph,
|
||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||
): void => {
|
||||
): Promise<void> => {
|
||||
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
|
||||
const { boundingBoxScaleMethod } = state.canvas;
|
||||
const { refinerModel } = state.sdxl;
|
||||
@ -149,6 +151,8 @@ export const addVAEToGraph = (
|
||||
}
|
||||
|
||||
if (vae) {
|
||||
upsertMetadata(graph, { vae });
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
|
||||
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
|
||||
upsertMetadata(graph, { vae: vaeMetadata });
|
||||
}
|
||||
};
|
||||
|
@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
|
||||
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
|
||||
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||
|
||||
export const buildCanvasGraph = (
|
||||
export const buildCanvasGraph = async (
|
||||
state: RootState,
|
||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||
canvasInitImage: ImageDTO | undefined,
|
||||
canvasMaskImage: ImageDTO | undefined
|
||||
) => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
let graph: NonNullableGraph;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLTextToImageGraph(state);
|
||||
graph = await buildCanvasSDXLTextToImageGraph(state);
|
||||
} else {
|
||||
graph = buildCanvasTextToImageGraph(state);
|
||||
graph = await buildCanvasTextToImageGraph(state);
|
||||
}
|
||||
} else if (generationMode === 'img2img') {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||
} else {
|
||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||
}
|
||||
} else if (generationMode === 'inpaint') {
|
||||
if (!canvasInitImage || !canvasMaskImage) {
|
||||
throw new Error('Missing canvas init and mask images');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
} else {
|
||||
if (!canvasInitImage) {
|
||||
throw new Error('Missing canvas init image');
|
||||
}
|
||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
} else {
|
||||
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type ImageDTO,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -25,12 +31,15 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
*/
|
||||
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
||||
export const buildCanvasImageToImageGraph = async (
|
||||
state: RootState,
|
||||
initialImage: ImageDTO
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
*/
|
||||
export const buildCanvasInpaintGraph = (
|
||||
export const buildCanvasInpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = (
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
// must add before watermarker!
|
||||
|
@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
||||
/**
|
||||
* Builds the Canvas tab's Outpaint graph.
|
||||
*/
|
||||
export const buildCanvasOutpaintGraph = (
|
||||
export const buildCanvasOutpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = (
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type ImageDTO,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -26,12 +32,15 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Image to Image graph.
|
||||
*/
|
||||
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
||||
export const buildCanvasSDXLImageToImageGraph = async (
|
||||
state: RootState,
|
||||
initialImage: ImageDTO
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
||||
/**
|
||||
* Builds the Canvas tab's Inpaint graph.
|
||||
*/
|
||||
export const buildCanvasSDXLInpaintGraph = (
|
||||
export const buildCanvasSDXLInpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
|
||||
await addSDXLRefinerToGraph(
|
||||
state,
|
||||
graph,
|
||||
SDXL_DENOISE_LATENTS,
|
||||
modelLoaderNodeId,
|
||||
canvasInitImage,
|
||||
canvasMaskImage
|
||||
);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
||||
/**
|
||||
* Builds the Canvas tab's Outpaint graph.
|
||||
*/
|
||||
export const buildCanvasSDXLOutpaintGraph = (
|
||||
export const buildCanvasSDXLOutpaintGraph = async (
|
||||
state: RootState,
|
||||
canvasInitImage: ImageDTO,
|
||||
canvasMaskImage?: ImageDTO
|
||||
): NonNullableGraph => {
|
||||
): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = (
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// Add VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -24,12 +25,12 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
negative_prompt: negativePrompt,
|
||||
positive_style_prompt: positiveStylePrompt,
|
||||
negative_style_prompt: negativeStylePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -23,12 +24,12 @@ import {
|
||||
POSITIVE_CONDITIONING,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Canvas tab's Text to Image graph.
|
||||
*/
|
||||
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import {
|
||||
type ImageResizeInvocation,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -24,12 +30,12 @@ import {
|
||||
RESIZE,
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
type ImageResizeInvocation,
|
||||
type ImageToLatentsInvocation,
|
||||
isNonRefinerMainModelConfig,
|
||||
type NonNullableGraph,
|
||||
} from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -25,12 +31,12 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
/**
|
||||
* Builds the Image to Image tab graph.
|
||||
*/
|
||||
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
});
|
||||
}
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// Add LoRA Support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// Add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||
@ -23,9 +24,9 @@ import {
|
||||
SEAMLESS,
|
||||
} from './constants';
|
||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
||||
|
||||
// Add Refiner if enabled
|
||||
if (refinerModel) {
|
||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
if (seamlessXAxis || seamlessYAxis) {
|
||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||
}
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||
|
||||
// NSFW & watermark - must be last thing added to graph
|
||||
if (state.system.shouldUseNSFWChecker) {
|
||||
|
@ -1,7 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { NonNullableGraph } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||
import { addHrfToGraph } from './addHrfToGraph';
|
||||
@ -23,9 +24,9 @@ import {
|
||||
SEAMLESS,
|
||||
TEXT_TO_IMAGE_GRAPH,
|
||||
} from './constants';
|
||||
import { addCoreMetadataNode } from './metadata';
|
||||
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||
|
||||
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
|
||||
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||
const log = logger('nodes');
|
||||
const {
|
||||
positivePrompt,
|
||||
@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
],
|
||||
};
|
||||
|
||||
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||
|
||||
addCoreMetadataNode(
|
||||
graph,
|
||||
{
|
||||
@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
width,
|
||||
positive_prompt: positivePrompt,
|
||||
negative_prompt: negativePrompt,
|
||||
model,
|
||||
model: getModelMetadataField(modelConfig),
|
||||
seed,
|
||||
steps,
|
||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||
@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
||||
}
|
||||
|
||||
// optionally add custom VAE
|
||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||
|
||||
// add LoRA support
|
||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||
|
||||
// add controlnet, mutating `graph`
|
||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// add IP Adapter
|
||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||
|
||||
// High resolution fix.
|
||||
if (state.hrf.hrfEnabled) {
|
||||
|
@ -1,5 +1,5 @@
|
||||
import type { JSONObject } from 'common/types';
|
||||
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
|
||||
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||
|
||||
import { METADATA } from './constants';
|
||||
|
||||
@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
|
||||
key,
|
||||
hash,
|
||||
name,
|
||||
base,
|
||||
type,
|
||||
});
|
||||
|
@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
|
||||
export type ModelType = S['ModelType'];
|
||||
export type SubModelType = S['SubModelType'];
|
||||
export type BaseModelType = S['BaseModelType'];
|
||||
export type ControlField = S['ControlField'];
|
||||
|
||||
// Model Configs
|
||||
|
||||
@ -129,8 +128,9 @@ export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
|
||||
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
|
||||
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
|
||||
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
|
||||
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
|
||||
export type T2IAdapterField = S['T2IAdapterField'];
|
||||
|
||||
// Metadata fields
|
||||
export type ModelMetadataField = S['ModelMetadataField'];
|
||||
|
||||
// ControlNet Nodes
|
||||
export type ControlNetInvocation = S['ControlNetInvocation'];
|
||||
|
Loading…
x
Reference in New Issue
Block a user