mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix metadata for graphs to use new enriched format
This commit is contained in:
parent
d4a2ea68fc
commit
c57f6ee885
@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis
|
|||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage);
|
||||||
|
|
||||||
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
log.debug({ graph: parseify(graph) }, `Canvas graph built`);
|
||||||
|
|
||||||
|
@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
|||||||
|
|
||||||
if (model && model.base === 'sdxl') {
|
if (model && model.base === 'sdxl') {
|
||||||
if (action.payload.tabName === 'txt2img') {
|
if (action.payload.tabName === 'txt2img') {
|
||||||
graph = buildLinearSDXLTextToImageGraph(state);
|
graph = await buildLinearSDXLTextToImageGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = buildLinearSDXLImageToImageGraph(state);
|
graph = await buildLinearSDXLImageToImageGraph(state);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (action.payload.tabName === 'txt2img') {
|
if (action.payload.tabName === 'txt2img') {
|
||||||
graph = buildLinearTextToImageGraph(state);
|
graph = await buildLinearTextToImageGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = buildLinearImageToImageGraph(state);
|
graph = await buildLinearImageToImageGraph(state);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,8 +55,22 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
|||||||
|
|
||||||
// #region Model-related schemas
|
// #region Model-related schemas
|
||||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||||
|
const zSubModelType = z.enum([
|
||||||
|
'unet',
|
||||||
|
'text_encoder',
|
||||||
|
'text_encoder_2',
|
||||||
|
'tokenizer',
|
||||||
|
'tokenizer_2',
|
||||||
|
'vae',
|
||||||
|
'vae_decoder',
|
||||||
|
'vae_encoder',
|
||||||
|
'scheduler',
|
||||||
|
'safety_checker',
|
||||||
|
]);
|
||||||
|
|
||||||
const zModelIdentifier = z.object({
|
const zModelIdentifier = z.object({
|
||||||
key: z.string().min(1),
|
key: z.string().min(1),
|
||||||
|
submodel_type: zSubModelType.nullish(),
|
||||||
});
|
});
|
||||||
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
export const isModelIdentifier = (field: unknown): field is ModelIdentifier =>
|
||||||
zModelIdentifier.safeParse(field).success;
|
zModelIdentifier.safeParse(field).success;
|
||||||
|
@ -1,18 +1,22 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { omit } from 'lodash-es';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
ControlField,
|
|
||||||
ControlNetInvocation,
|
ControlNetInvocation,
|
||||||
CoreMetadataInvocation,
|
CoreMetadataInvocation,
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
import { isControlNetModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { CONTROL_NET_COLLECT } from './constants';
|
import { CONTROL_NET_COLLECT } from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
export const addControlNetToLinearGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
baseNodeId: string
|
||||||
|
): Promise<void> => {
|
||||||
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
const validControlNets = selectValidControlNets(state.controlAdapters).filter(
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
(ca) => ca.model?.base === state.generation.model?.base
|
||||||
);
|
);
|
||||||
@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
validControlNets.forEach((controlNet) => {
|
validControlNets.forEach(async (controlNet) => {
|
||||||
if (!controlNet.model) {
|
if (!controlNet.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG
|
|||||||
|
|
||||||
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation;
|
||||||
|
|
||||||
controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig);
|
||||||
|
|
||||||
|
controlNetMetadata.push({
|
||||||
|
control_model: getModelMetadataField(modelConfig),
|
||||||
|
control_weight: weight,
|
||||||
|
control_mode: controlMode,
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
resize_mode: resizeMode,
|
||||||
|
image: controlNetNode.image,
|
||||||
|
});
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: controlNetNode.id, field: 'control' },
|
source: { node_id: controlNetNode.id, field: 'control' },
|
||||||
|
@ -1,18 +1,22 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { omit } from 'lodash-es';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type {
|
import type {
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
CoreMetadataInvocation,
|
CoreMetadataInvocation,
|
||||||
IPAdapterInvocation,
|
IPAdapterInvocation,
|
||||||
IPAdapterMetadataField,
|
|
||||||
NonNullableGraph,
|
NonNullableGraph,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
import { isIPAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { IP_ADAPTER_COLLECT } from './constants';
|
import { IP_ADAPTER_COLLECT } from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
export const addIPAdapterToLinearGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
baseNodeId: string
|
||||||
|
): Promise<void> => {
|
||||||
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter(
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
(ca) => ca.model?.base === state.generation.model?.base
|
||||||
);
|
);
|
||||||
@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
|||||||
|
|
||||||
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = [];
|
||||||
|
|
||||||
validIPAdapters.forEach((ipAdapter) => {
|
validIPAdapters.forEach(async (ipAdapter) => {
|
||||||
if (!ipAdapter.model) {
|
if (!ipAdapter.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode;
|
||||||
|
|
||||||
ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField);
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig);
|
||||||
|
|
||||||
|
ipAdapterMetdata.push({
|
||||||
|
weight: weight,
|
||||||
|
ip_adapter_model: getModelMetadataField(modelConfig),
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
image: ipAdapterNode.image,
|
||||||
|
});
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
|
@ -1,16 +1,22 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types';
|
import {
|
||||||
|
type CoreMetadataInvocation,
|
||||||
|
isLoRAModelConfig,
|
||||||
|
type LoraLoaderInvocation,
|
||||||
|
type NonNullableGraph,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addLoRAsToGraph = (
|
export const addLoRAsToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||||
): void => {
|
): Promise<void> => {
|
||||||
/**
|
/**
|
||||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||||
@ -39,7 +45,7 @@ export const addLoRAsToGraph = (
|
|||||||
let currentLoraIndex = 0;
|
let currentLoraIndex = 0;
|
||||||
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
const loraMetadata: CoreMetadataInvocation['loras'] = [];
|
||||||
|
|
||||||
enabledLoRAs.forEach((lora) => {
|
enabledLoRAs.forEach(async (lora) => {
|
||||||
const { weight } = lora;
|
const { weight } = lora;
|
||||||
const { key } = lora.model;
|
const { key } = lora.model;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||||
@ -52,8 +58,10 @@ export const addLoRAsToGraph = (
|
|||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
|
||||||
loraMetadata.push({
|
loraMetadata.push({
|
||||||
model: { key },
|
model: getModelMetadataField(modelConfig),
|
||||||
weight,
|
weight,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { filter, size } from 'lodash-es';
|
import { filter, size } from 'lodash-es';
|
||||||
import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types';
|
import {
|
||||||
|
type CoreMetadataInvocation,
|
||||||
|
isLoRAModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
type SDXLLoraLoaderInvocation,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
LORA_LOADER,
|
LORA_LOADER,
|
||||||
@ -10,14 +16,14 @@ import {
|
|||||||
SDXL_REFINER_INPAINT_CREATE_MASK,
|
SDXL_REFINER_INPAINT_CREATE_MASK,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addSDXLLoRAsToGraph = (
|
export const addSDXLLoRAsToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId: string = SDXL_MODEL_LOADER
|
modelLoaderNodeId: string = SDXL_MODEL_LOADER
|
||||||
): void => {
|
): Promise<void> => {
|
||||||
/**
|
/**
|
||||||
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
* LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them.
|
||||||
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
* They then output the UNet and CLIP models references on to either the next LoRA in the chain,
|
||||||
@ -55,7 +61,7 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
let lastLoraNodeId = '';
|
let lastLoraNodeId = '';
|
||||||
let currentLoraIndex = 0;
|
let currentLoraIndex = 0;
|
||||||
|
|
||||||
enabledLoRAs.forEach((lora) => {
|
enabledLoRAs.forEach(async (lora) => {
|
||||||
const { weight } = lora;
|
const { weight } = lora;
|
||||||
const { key } = lora.model;
|
const { key } = lora.model;
|
||||||
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
const currentLoraNodeId = `${LORA_LOADER}_${key}`;
|
||||||
@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = (
|
|||||||
weight,
|
weight,
|
||||||
};
|
};
|
||||||
|
|
||||||
loraMetadata.push({ model: { key }, weight });
|
const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
|
||||||
|
loraMetadata.push({ model: getModelMetadataField(modelConfig), weight });
|
||||||
|
|
||||||
// add to graph
|
// add to graph
|
||||||
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
graph.nodes[currentLoraNodeId] = loraLoaderNode;
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type {
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
CreateDenoiseMaskInvocation,
|
import {
|
||||||
ImageDTO,
|
type CreateDenoiseMaskInvocation,
|
||||||
NonNullableGraph,
|
type ImageDTO,
|
||||||
SeamlessModeInvocation,
|
isRefinerMainModelModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
type SeamlessModeInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@ -25,16 +27,16 @@ import {
|
|||||||
SDXL_REFINER_SEAMLESS,
|
SDXL_REFINER_SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addSDXLRefinerToGraph = (
|
export const addSDXLRefinerToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
baseNodeId: string,
|
baseNodeId: string,
|
||||||
modelLoaderNodeId?: string,
|
modelLoaderNodeId?: string,
|
||||||
canvasInitImage?: ImageDTO,
|
canvasInitImage?: ImageDTO,
|
||||||
canvasMaskImage?: ImageDTO
|
canvasMaskImage?: ImageDTO
|
||||||
): void => {
|
): Promise<void> => {
|
||||||
const {
|
const {
|
||||||
refinerModel,
|
refinerModel,
|
||||||
refinerPositiveAestheticScore,
|
refinerPositiveAestheticScore,
|
||||||
@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = (
|
|||||||
const fp32 = vaePrecision === 'fp32';
|
const fp32 = vaePrecision === 'fp32';
|
||||||
|
|
||||||
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod);
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig);
|
||||||
|
|
||||||
upsertMetadata(graph, {
|
upsertMetadata(graph, {
|
||||||
refiner_model: refinerModel,
|
refiner_model: getModelMetadataField(modelConfig),
|
||||||
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
refiner_positive_aesthetic_score: refinerPositiveAestheticScore,
|
||||||
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
refiner_negative_aesthetic_score: refinerNegativeAestheticScore,
|
||||||
refiner_cfg_scale: refinerCFGScale,
|
refiner_cfg_scale: refinerCFGScale,
|
||||||
|
@ -1,18 +1,22 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { omit } from 'lodash-es';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type {
|
import {
|
||||||
CollectInvocation,
|
type CollectInvocation,
|
||||||
CoreMetadataInvocation,
|
type CoreMetadataInvocation,
|
||||||
NonNullableGraph,
|
isT2IAdapterModelConfig,
|
||||||
T2IAdapterField,
|
type NonNullableGraph,
|
||||||
T2IAdapterInvocation,
|
type T2IAdapterInvocation,
|
||||||
} from 'services/api/types';
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { T2I_ADAPTER_COLLECT } from './constants';
|
import { T2I_ADAPTER_COLLECT } from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => {
|
export const addT2IAdaptersToLinearGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
baseNodeId: string
|
||||||
|
): Promise<void> => {
|
||||||
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter(
|
||||||
(ca) => ca.model?.base === state.generation.model?.base
|
(ca) => ca.model?.base === state.generation.model?.base
|
||||||
);
|
);
|
||||||
@ -35,7 +39,7 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
|||||||
|
|
||||||
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = [];
|
||||||
|
|
||||||
validT2IAdapters.forEach((t2iAdapter) => {
|
validT2IAdapters.forEach(async (t2iAdapter) => {
|
||||||
if (!t2iAdapter.model) {
|
if (!t2iAdapter.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation;
|
graph.nodes[t2iAdapterNode.id] = t2iAdapterNode;
|
||||||
|
|
||||||
t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField);
|
const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig);
|
||||||
|
|
||||||
|
t2iAdapterMetdata.push({
|
||||||
|
begin_step_percent: beginStepPct,
|
||||||
|
end_step_percent: endStepPct,
|
||||||
|
resize_mode: resizeMode,
|
||||||
|
t2i_adapter_model: getModelMetadataField(modelConfig),
|
||||||
|
weight: weight,
|
||||||
|
image: t2iAdapterNode.image,
|
||||||
|
});
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' },
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { NonNullableGraph } from 'services/api/types';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import type { ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||||
|
import { isVAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
CANVAS_IMAGE_TO_IMAGE_GRAPH,
|
||||||
@ -23,13 +25,13 @@ import {
|
|||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
VAE_LOADER,
|
VAE_LOADER,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { upsertMetadata } from './metadata';
|
import { getModelMetadataField, upsertMetadata } from './metadata';
|
||||||
|
|
||||||
export const addVAEToGraph = (
|
export const addVAEToGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
graph: NonNullableGraph,
|
graph: NonNullableGraph,
|
||||||
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
modelLoaderNodeId: string = MAIN_MODEL_LOADER
|
||||||
): void => {
|
): Promise<void> => {
|
||||||
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
|
const { vae, seamlessXAxis, seamlessYAxis } = state.generation;
|
||||||
const { boundingBoxScaleMethod } = state.canvas;
|
const { boundingBoxScaleMethod } = state.canvas;
|
||||||
const { refinerModel } = state.sdxl;
|
const { refinerModel } = state.sdxl;
|
||||||
@ -149,6 +151,8 @@ export const addVAEToGraph = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
upsertMetadata(graph, { vae });
|
const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig);
|
||||||
|
const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig);
|
||||||
|
upsertMetadata(graph, { vae: vaeMetadata });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph';
|
|||||||
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
|
import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph';
|
||||||
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph';
|
||||||
|
|
||||||
export const buildCanvasGraph = (
|
export const buildCanvasGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint',
|
||||||
canvasInitImage: ImageDTO | undefined,
|
canvasInitImage: ImageDTO | undefined,
|
||||||
canvasMaskImage: ImageDTO | undefined
|
canvasMaskImage: ImageDTO | undefined
|
||||||
) => {
|
): Promise<NonNullableGraph> => {
|
||||||
let graph: NonNullableGraph;
|
let graph: NonNullableGraph;
|
||||||
|
|
||||||
if (generationMode === 'txt2img') {
|
if (generationMode === 'txt2img') {
|
||||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||||
graph = buildCanvasSDXLTextToImageGraph(state);
|
graph = await buildCanvasSDXLTextToImageGraph(state);
|
||||||
} else {
|
} else {
|
||||||
graph = buildCanvasTextToImageGraph(state);
|
graph = await buildCanvasTextToImageGraph(state);
|
||||||
}
|
}
|
||||||
} else if (generationMode === 'img2img') {
|
} else if (generationMode === 'img2img') {
|
||||||
if (!canvasInitImage) {
|
if (!canvasInitImage) {
|
||||||
throw new Error('Missing canvas init image');
|
throw new Error('Missing canvas init image');
|
||||||
}
|
}
|
||||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||||
graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage);
|
||||||
} else {
|
} else {
|
||||||
graph = buildCanvasImageToImageGraph(state, canvasInitImage);
|
graph = await buildCanvasImageToImageGraph(state, canvasInitImage);
|
||||||
}
|
}
|
||||||
} else if (generationMode === 'inpaint') {
|
} else if (generationMode === 'inpaint') {
|
||||||
if (!canvasInitImage || !canvasMaskImage) {
|
if (!canvasInitImage || !canvasMaskImage) {
|
||||||
throw new Error('Missing canvas init and mask images');
|
throw new Error('Missing canvas init and mask images');
|
||||||
}
|
}
|
||||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||||
graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||||
} else {
|
} else {
|
||||||
graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!canvasInitImage) {
|
if (!canvasInitImage) {
|
||||||
throw new Error('Missing canvas init image');
|
throw new Error('Missing canvas init image');
|
||||||
}
|
}
|
||||||
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
if (state.generation.model && state.generation.model.base === 'sdxl') {
|
||||||
graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||||
} else {
|
} else {
|
||||||
graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
import {
|
||||||
|
type ImageDTO,
|
||||||
|
type ImageToLatentsInvocation,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -25,12 +31,15 @@ import {
|
|||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Image to Image graph.
|
* Builds the Canvas tab's Image to Image graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
export const buildCanvasImageToImageGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
initialImage: ImageDTO
|
||||||
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
|||||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima
|
|||||||
}
|
}
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
|||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Inpaint graph.
|
* Builds the Canvas tab's Inpaint graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasInpaintGraph = (
|
export const buildCanvasInpaintGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
canvasInitImage: ImageDTO,
|
canvasInitImage: ImageDTO,
|
||||||
canvasMaskImage: ImageDTO
|
canvasMaskImage: ImageDTO
|
||||||
): NonNullableGraph => {
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add VAE
|
// Add VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
// must add before watermarker!
|
// must add before watermarker!
|
||||||
|
@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
|
|||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Outpaint graph.
|
* Builds the Canvas tab's Outpaint graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasOutpaintGraph = (
|
export const buildCanvasOutpaintGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
canvasInitImage: ImageDTO,
|
canvasInitImage: ImageDTO,
|
||||||
canvasMaskImage?: ImageDTO
|
canvasMaskImage?: ImageDTO
|
||||||
): NonNullableGraph => {
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add VAE
|
// Add VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import {
|
||||||
|
type ImageDTO,
|
||||||
|
type ImageToLatentsInvocation,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -26,12 +32,15 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Image to Image graph.
|
* Builds the Canvas tab's Image to Image graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => {
|
export const buildCanvasSDXLImageToImageGraph = async (
|
||||||
|
state: RootState,
|
||||||
|
initialImage: ImageDTO
|
||||||
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
|||||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage:
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
|||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Inpaint graph.
|
* Builds the Canvas tab's Inpaint graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasSDXLInpaintGraph = (
|
export const buildCanvasSDXLInpaintGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
canvasInitImage: ImageDTO,
|
canvasInitImage: ImageDTO,
|
||||||
canvasMaskImage: ImageDTO
|
canvasMaskImage: ImageDTO
|
||||||
): NonNullableGraph => {
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage);
|
await addSDXLRefinerToGraph(
|
||||||
|
state,
|
||||||
|
graph,
|
||||||
|
SDXL_DENOISE_LATENTS,
|
||||||
|
modelLoaderNodeId,
|
||||||
|
canvasInitImage,
|
||||||
|
canvasMaskImage
|
||||||
|
);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add VAE
|
// Add VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu
|
|||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Outpaint graph.
|
* Builds the Canvas tab's Outpaint graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasSDXLOutpaintGraph = (
|
export const buildCanvasSDXLOutpaintGraph = async (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
canvasInitImage: ImageDTO,
|
canvasInitImage: ImageDTO,
|
||||||
canvasMaskImage?: ImageDTO
|
canvasMaskImage?: ImageDTO
|
||||||
): NonNullableGraph => {
|
): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = (
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add VAE
|
// Add VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { NonNullableGraph } from 'services/api/types';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -24,12 +25,12 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
positive_style_prompt: positiveStylePrompt,
|
positive_style_prompt: positiveStylePrompt,
|
||||||
negative_style_prompt: negativeStylePrompt,
|
negative_style_prompt: negativeStylePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -23,12 +24,12 @@ import {
|
|||||||
POSITIVE_CONDITIONING,
|
POSITIVE_CONDITIONING,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Canvas tab's Text to Image graph.
|
* Builds the Canvas tab's Text to Image graph.
|
||||||
*/
|
*/
|
||||||
export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildCanvasTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,7 +1,13 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
import {
|
||||||
|
type ImageResizeInvocation,
|
||||||
|
type ImageToLatentsInvocation,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -24,12 +30,12 @@ import {
|
|||||||
RESIZE,
|
RESIZE,
|
||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Image to Image tab graph.
|
* Builds the Image to Image tab graph.
|
||||||
*/
|
*/
|
||||||
export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildLinearImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import {
|
||||||
|
type ImageResizeInvocation,
|
||||||
|
type ImageToLatentsInvocation,
|
||||||
|
isNonRefinerMainModelConfig,
|
||||||
|
type NonNullableGraph,
|
||||||
|
} from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -25,12 +31,12 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Image to Image tab graph.
|
* Builds the Image to Image tab graph.
|
||||||
*/
|
*/
|
||||||
export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
|||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// Add LoRA Support
|
// Add LoRA Support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// Add IP Adapter
|
// Add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import type { NonNullableGraph } from 'services/api/types';
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
|
||||||
@ -23,9 +24,9 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr
|
|||||||
|
|
||||||
// Add Refiner if enabled
|
// Add Refiner if enabled
|
||||||
if (refinerModel) {
|
if (refinerModel) {
|
||||||
addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
if (seamlessXAxis || seamlessYAxis) {
|
if (seamlessXAxis || seamlessYAxis) {
|
||||||
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
modelLoaderNodeId = SDXL_REFINER_SEAMLESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// add IP Adapter
|
// add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
|
||||||
|
|
||||||
// NSFW & watermark - must be last thing added to graph
|
// NSFW & watermark - must be last thing added to graph
|
||||||
if (state.system.shouldUseNSFWChecker) {
|
if (state.system.shouldUseNSFWChecker) {
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
|
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||||
import type { NonNullableGraph } from 'services/api/types';
|
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
|
||||||
import { addHrfToGraph } from './addHrfToGraph';
|
import { addHrfToGraph } from './addHrfToGraph';
|
||||||
@ -23,9 +24,9 @@ import {
|
|||||||
SEAMLESS,
|
SEAMLESS,
|
||||||
TEXT_TO_IMAGE_GRAPH,
|
TEXT_TO_IMAGE_GRAPH,
|
||||||
} from './constants';
|
} from './constants';
|
||||||
import { addCoreMetadataNode } from './metadata';
|
import { addCoreMetadataNode, getModelMetadataField } from './metadata';
|
||||||
|
|
||||||
export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => {
|
export const buildLinearTextToImageGraph = async (state: RootState): Promise<NonNullableGraph> => {
|
||||||
const log = logger('nodes');
|
const log = logger('nodes');
|
||||||
const {
|
const {
|
||||||
positivePrompt,
|
positivePrompt,
|
||||||
@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
|
||||||
|
|
||||||
addCoreMetadataNode(
|
addCoreMetadataNode(
|
||||||
graph,
|
graph,
|
||||||
{
|
{
|
||||||
@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
width,
|
width,
|
||||||
positive_prompt: positivePrompt,
|
positive_prompt: positivePrompt,
|
||||||
negative_prompt: negativePrompt,
|
negative_prompt: negativePrompt,
|
||||||
model,
|
model: getModelMetadataField(modelConfig),
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
rand_device: use_cpu ? 'cpu' : 'cuda',
|
rand_device: use_cpu ? 'cpu' : 'cuda',
|
||||||
@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
// optionally add custom VAE
|
// optionally add custom VAE
|
||||||
addVAEToGraph(state, graph, modelLoaderNodeId);
|
await addVAEToGraph(state, graph, modelLoaderNodeId);
|
||||||
|
|
||||||
// add LoRA support
|
// add LoRA support
|
||||||
addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId);
|
||||||
|
|
||||||
// add controlnet, mutating `graph`
|
// add controlnet, mutating `graph`
|
||||||
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// add IP Adapter
|
// add IP Adapter
|
||||||
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
|
||||||
|
|
||||||
// High resolution fix.
|
// High resolution fix.
|
||||||
if (state.hrf.hrfEnabled) {
|
if (state.hrf.hrfEnabled) {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import type { JSONObject } from 'common/types';
|
import type { JSONObject } from 'common/types';
|
||||||
import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types';
|
import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types';
|
||||||
|
|
||||||
import { METADATA } from './constants';
|
import { METADATA } from './constants';
|
||||||
|
|
||||||
@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({
|
||||||
|
key,
|
||||||
|
hash,
|
||||||
|
name,
|
||||||
|
base,
|
||||||
|
type,
|
||||||
|
});
|
||||||
|
@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
|
|||||||
export type ModelType = S['ModelType'];
|
export type ModelType = S['ModelType'];
|
||||||
export type SubModelType = S['SubModelType'];
|
export type SubModelType = S['SubModelType'];
|
||||||
export type BaseModelType = S['BaseModelType'];
|
export type BaseModelType = S['BaseModelType'];
|
||||||
export type ControlField = S['ControlField'];
|
|
||||||
|
|
||||||
// Model Configs
|
// Model Configs
|
||||||
|
|
||||||
@ -129,8 +128,9 @@ export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
|
|||||||
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
|
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
|
||||||
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
|
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
|
||||||
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
|
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];
|
||||||
export type IPAdapterMetadataField = S['IPAdapterMetadataField'];
|
|
||||||
export type T2IAdapterField = S['T2IAdapterField'];
|
// Metadata fields
|
||||||
|
export type ModelMetadataField = S['ModelMetadataField'];
|
||||||
|
|
||||||
// ControlNet Nodes
|
// ControlNet Nodes
|
||||||
export type ControlNetInvocation = S['ControlNetInvocation'];
|
export type ControlNetInvocation = S['ControlNetInvocation'];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user