diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts index ed1f4fdd98..f38020b8ea 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts @@ -101,7 +101,7 @@ export const addEnqueueRequestedCanvasListener = (startAppListening: AppStartLis ).unwrap(); } - const graph = buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage); + const graph = await buildCanvasGraph(state, generationMode, canvasInitImage, canvasMaskImage); log.debug({ graph: parseify(graph) }, `Canvas graph built`); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index 337c0f4145..f923edb99a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -20,15 +20,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) if (model && model.base === 'sdxl') { if (action.payload.tabName === 'txt2img') { - graph = buildLinearSDXLTextToImageGraph(state); + graph = await buildLinearSDXLTextToImageGraph(state); } else { - graph = buildLinearSDXLImageToImageGraph(state); + graph = await buildLinearSDXLImageToImageGraph(state); } } else { if (action.payload.tabName === 'txt2img') { - graph = buildLinearTextToImageGraph(state); + graph = await buildLinearTextToImageGraph(state); } else { - graph = buildLinearImageToImageGraph(state); + graph = await buildLinearImageToImageGraph(state); } } diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index f665c7780d..cbbe150ed4 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -55,8 +55,22 @@ export type SchedulerField = z.infer; // #region Model-related schemas const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); + const zModelIdentifier = z.object({ key: z.string().min(1), + submodel_type: zSubModelType.nullish(), }); export const isModelIdentifier = (field: unknown): field is ModelIdentifier => zModelIdentifier.safeParse(field).success; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts index 8441a1ef21..12a163dd79 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts @@ -1,18 +1,22 @@ import type { RootState } from 'app/store/store'; import { selectValidControlNets } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { omit } from 'lodash-es'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import type { CollectInvocation, - ControlField, ControlNetInvocation, CoreMetadataInvocation, NonNullableGraph, } from 'services/api/types'; +import { isControlNetModelConfig } from 'services/api/types'; import { CONTROL_NET_COLLECT } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { +export const addControlNetToLinearGraph = async ( + state: RootState, + graph: NonNullableGraph, + baseNodeId: string +): Promise => { const validControlNets = selectValidControlNets(state.controlAdapters).filter( (ca) => ca.model?.base === state.generation.model?.base ); @@ -39,7 +43,7 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG }, }); - validControlNets.forEach((controlNet) => { + validControlNets.forEach(async (controlNet) => { if (!controlNet.model) { return; } @@ -85,7 +89,17 @@ export const addControlNetToLinearGraph = (state: RootState, graph: NonNullableG graph.nodes[controlNetNode.id] = controlNetNode as ControlNetInvocation; - controlNetMetadata.push(omit(controlNetNode, ['id', 'type', 'is_intermediate']) as ControlField); + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isControlNetModelConfig); + + controlNetMetadata.push({ + control_model: getModelMetadataField(modelConfig), + control_weight: weight, + control_mode: controlMode, + begin_step_percent: beginStepPct, + end_step_percent: endStepPct, + resize_mode: resizeMode, + image: controlNetNode.image, + }); graph.edges.push({ source: { node_id: controlNetNode.id, field: 'control' }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts index 1760310cae..7328b8f868 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts @@ -1,18 +1,22 @@ import type { RootState } from 'app/store/store'; import { selectValidIPAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { omit } from 'lodash-es'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import type { CollectInvocation, CoreMetadataInvocation, IPAdapterInvocation, - IPAdapterMetadataField, NonNullableGraph, } from 'services/api/types'; +import { isIPAdapterModelConfig } from 'services/api/types'; import { IP_ADAPTER_COLLECT } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { +export const addIPAdapterToLinearGraph = async ( + state: RootState, + graph: NonNullableGraph, + baseNodeId: string +): Promise => { const validIPAdapters = selectValidIPAdapters(state.controlAdapters).filter( (ca) => ca.model?.base === state.generation.model?.base ); @@ -35,7 +39,7 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr const ipAdapterMetdata: CoreMetadataInvocation['ipAdapters'] = []; - validIPAdapters.forEach((ipAdapter) => { + validIPAdapters.forEach(async (ipAdapter) => { if (!ipAdapter.model) { return; } @@ -58,9 +62,17 @@ export const addIPAdapterToLinearGraph = (state: RootState, graph: NonNullableGr return; } - graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation; + graph.nodes[ipAdapterNode.id] = ipAdapterNode; - ipAdapterMetdata.push(omit(ipAdapterNode, ['id', 'type', 'is_intermediate']) as IPAdapterMetadataField); + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isIPAdapterModelConfig); + + ipAdapterMetdata.push({ + weight: weight, + ip_adapter_model: getModelMetadataField(modelConfig), + begin_step_percent: beginStepPct, + end_step_percent: endStepPct, + image: ipAdapterNode.image, + }); graph.edges.push({ source: { node_id: ipAdapterNode.id, field: 'ip_adapter' }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts index c49a0e19e3..f88312f0de 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts @@ -1,16 +1,22 @@ import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { filter, size } from 'lodash-es'; -import type { CoreMetadataInvocation, LoraLoaderInvocation, NonNullableGraph } from 'services/api/types'; +import { + type CoreMetadataInvocation, + isLoRAModelConfig, + type LoraLoaderInvocation, + type NonNullableGraph, +} from 'services/api/types'; import { CLIP_SKIP, LORA_LOADER, MAIN_MODEL_LOADER, NEGATIVE_CONDITIONING, POSITIVE_CONDITIONING } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addLoRAsToGraph = ( +export const addLoRAsToGraph = async ( state: RootState, graph: NonNullableGraph, baseNodeId: string, modelLoaderNodeId: string = MAIN_MODEL_LOADER -): void => { +): Promise => { /** * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. * They then output the UNet and CLIP models references on to either the next LoRA in the chain, @@ -39,7 +45,7 @@ export const addLoRAsToGraph = ( let currentLoraIndex = 0; const loraMetadata: CoreMetadataInvocation['loras'] = []; - enabledLoRAs.forEach((lora) => { + enabledLoRAs.forEach(async (lora) => { const { weight } = lora; const { key } = lora.model; const currentLoraNodeId = `${LORA_LOADER}_${key}`; @@ -52,8 +58,10 @@ export const addLoRAsToGraph = ( weight, }; + const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); + loraMetadata.push({ - model: { key }, + model: getModelMetadataField(modelConfig), weight, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts index 36223391fe..e150eaface 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts @@ -1,6 +1,12 @@ import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { filter, size } from 'lodash-es'; -import type { CoreMetadataInvocation, NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types'; +import { + type CoreMetadataInvocation, + isLoRAModelConfig, + type NonNullableGraph, + type SDXLLoraLoaderInvocation, +} from 'services/api/types'; import { LORA_LOADER, @@ -10,14 +16,14 @@ import { SDXL_REFINER_INPAINT_CREATE_MASK, SEAMLESS, } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addSDXLLoRAsToGraph = ( +export const addSDXLLoRAsToGraph = async ( state: RootState, graph: NonNullableGraph, baseNodeId: string, modelLoaderNodeId: string = SDXL_MODEL_LOADER -): void => { +): Promise => { /** * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. * They then output the UNet and CLIP models references on to either the next LoRA in the chain, @@ -55,7 +61,7 @@ export const addSDXLLoRAsToGraph = ( let lastLoraNodeId = ''; let currentLoraIndex = 0; - enabledLoRAs.forEach((lora) => { + enabledLoRAs.forEach(async (lora) => { const { weight } = lora; const { key } = lora.model; const currentLoraNodeId = `${LORA_LOADER}_${key}`; @@ -68,7 +74,9 @@ export const addSDXLLoRAsToGraph = ( weight, }; - loraMetadata.push({ model: { key }, weight }); + const modelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig); + + loraMetadata.push({ model: getModelMetadataField(modelConfig), weight }); // add to graph graph.nodes[currentLoraNodeId] = loraLoaderNode; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts index fc4d998969..24a83d9d19 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts @@ -1,9 +1,11 @@ import type { RootState } from 'app/store/store'; -import type { - CreateDenoiseMaskInvocation, - ImageDTO, - NonNullableGraph, - SeamlessModeInvocation, +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + type CreateDenoiseMaskInvocation, + type ImageDTO, + isRefinerMainModelModelConfig, + type NonNullableGraph, + type SeamlessModeInvocation, } from 'services/api/types'; import { @@ -25,16 +27,16 @@ import { SDXL_REFINER_SEAMLESS, } from './constants'; import { getSDXLStylePrompts } from './graphBuilderUtils'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addSDXLRefinerToGraph = ( +export const addSDXLRefinerToGraph = async ( state: RootState, graph: NonNullableGraph, baseNodeId: string, modelLoaderNodeId?: string, canvasInitImage?: ImageDTO, canvasMaskImage?: ImageDTO -): void => { +): Promise => { const { refinerModel, refinerPositiveAestheticScore, @@ -55,9 +57,10 @@ export const addSDXLRefinerToGraph = ( const fp32 = vaePrecision === 'fp32'; const isUsingScaledDimensions = ['auto', 'manual'].includes(boundingBoxScaleMethod); + const modelConfig = await fetchModelConfigWithTypeGuard(refinerModel.key, isRefinerMainModelModelConfig); upsertMetadata(graph, { - refiner_model: refinerModel, + refiner_model: getModelMetadataField(modelConfig), refiner_positive_aesthetic_score: refinerPositiveAestheticScore, refiner_negative_aesthetic_score: refinerNegativeAestheticScore, refiner_cfg_scale: refinerCFGScale, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts index a765a72a9f..b2d2488970 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts @@ -1,18 +1,22 @@ import type { RootState } from 'app/store/store'; import { selectValidT2IAdapters } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { omit } from 'lodash-es'; -import type { - CollectInvocation, - CoreMetadataInvocation, - NonNullableGraph, - T2IAdapterField, - T2IAdapterInvocation, +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + type CollectInvocation, + type CoreMetadataInvocation, + isT2IAdapterModelConfig, + type NonNullableGraph, + type T2IAdapterInvocation, } from 'services/api/types'; import { T2I_ADAPTER_COLLECT } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullableGraph, baseNodeId: string): void => { +export const addT2IAdaptersToLinearGraph = async ( + state: RootState, + graph: NonNullableGraph, + baseNodeId: string +): Promise => { const validT2IAdapters = selectValidT2IAdapters(state.controlAdapters).filter( (ca) => ca.model?.base === state.generation.model?.base ); @@ -35,7 +39,7 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable const t2iAdapterMetdata: CoreMetadataInvocation['t2iAdapters'] = []; - validT2IAdapters.forEach((t2iAdapter) => { + validT2IAdapters.forEach(async (t2iAdapter) => { if (!t2iAdapter.model) { return; } @@ -77,9 +81,18 @@ export const addT2IAdaptersToLinearGraph = (state: RootState, graph: NonNullable return; } - graph.nodes[t2iAdapterNode.id] = t2iAdapterNode as T2IAdapterInvocation; + graph.nodes[t2iAdapterNode.id] = t2iAdapterNode; - t2iAdapterMetdata.push(omit(t2iAdapterNode, ['id', 'type', 'is_intermediate']) as T2IAdapterField); + const modelConfig = await fetchModelConfigWithTypeGuard(t2iAdapter.model.key, isT2IAdapterModelConfig); + + t2iAdapterMetdata.push({ + begin_step_percent: beginStepPct, + end_step_percent: endStepPct, + resize_mode: resizeMode, + t2i_adapter_model: getModelMetadataField(modelConfig), + weight: weight, + image: t2iAdapterNode.image, + }); graph.edges.push({ source: { node_id: t2iAdapterNode.id, field: 't2i_adapter' }, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts index 1cda6948a1..58c92eddd7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts @@ -1,5 +1,7 @@ import type { RootState } from 'app/store/store'; -import type { NonNullableGraph } from 'services/api/types'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import type { ModelMetadataField, NonNullableGraph } from 'services/api/types'; +import { isVAEModelConfig } from 'services/api/types'; import { CANVAS_IMAGE_TO_IMAGE_GRAPH, @@ -23,13 +25,13 @@ import { TEXT_TO_IMAGE_GRAPH, VAE_LOADER, } from './constants'; -import { upsertMetadata } from './metadata'; +import { getModelMetadataField, upsertMetadata } from './metadata'; -export const addVAEToGraph = ( +export const addVAEToGraph = async ( state: RootState, graph: NonNullableGraph, modelLoaderNodeId: string = MAIN_MODEL_LOADER -): void => { +): Promise => { const { vae, seamlessXAxis, seamlessYAxis } = state.generation; const { boundingBoxScaleMethod } = state.canvas; const { refinerModel } = state.sdxl; @@ -149,6 +151,8 @@ export const addVAEToGraph = ( } if (vae) { - upsertMetadata(graph, { vae }); + const modelConfig = await fetchModelConfigWithTypeGuard(vae.key, isVAEModelConfig); + const vaeMetadata: ModelMetadataField = getModelMetadataField(modelConfig); + upsertMetadata(graph, { vae: vaeMetadata }); } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts index 4ce2e4d673..abf8b88773 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts @@ -10,46 +10,46 @@ import { buildCanvasSDXLOutpaintGraph } from './buildCanvasSDXLOutpaintGraph'; import { buildCanvasSDXLTextToImageGraph } from './buildCanvasSDXLTextToImageGraph'; import { buildCanvasTextToImageGraph } from './buildCanvasTextToImageGraph'; -export const buildCanvasGraph = ( +export const buildCanvasGraph = async ( state: RootState, generationMode: 'txt2img' | 'img2img' | 'inpaint' | 'outpaint', canvasInitImage: ImageDTO | undefined, canvasMaskImage: ImageDTO | undefined -) => { +): Promise => { let graph: NonNullableGraph; if (generationMode === 'txt2img') { if (state.generation.model && state.generation.model.base === 'sdxl') { - graph = buildCanvasSDXLTextToImageGraph(state); + graph = await buildCanvasSDXLTextToImageGraph(state); } else { - graph = buildCanvasTextToImageGraph(state); + graph = await buildCanvasTextToImageGraph(state); } } else if (generationMode === 'img2img') { if (!canvasInitImage) { throw new Error('Missing canvas init image'); } if (state.generation.model && state.generation.model.base === 'sdxl') { - graph = buildCanvasSDXLImageToImageGraph(state, canvasInitImage); + graph = await buildCanvasSDXLImageToImageGraph(state, canvasInitImage); } else { - graph = buildCanvasImageToImageGraph(state, canvasInitImage); + graph = await buildCanvasImageToImageGraph(state, canvasInitImage); } } else if (generationMode === 'inpaint') { if (!canvasInitImage || !canvasMaskImage) { throw new Error('Missing canvas init and mask images'); } if (state.generation.model && state.generation.model.base === 'sdxl') { - graph = buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage); + graph = await buildCanvasSDXLInpaintGraph(state, canvasInitImage, canvasMaskImage); } else { - graph = buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); + graph = await buildCanvasInpaintGraph(state, canvasInitImage, canvasMaskImage); } } else { if (!canvasInitImage) { throw new Error('Missing canvas init image'); } if (state.generation.model && state.generation.model.base === 'sdxl') { - graph = buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage); + graph = await buildCanvasSDXLOutpaintGraph(state, canvasInitImage, canvasMaskImage); } else { - graph = buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage); + graph = await buildCanvasOutpaintGraph(state, canvasInitImage, canvasMaskImage); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts index bc6a83f4fa..16c42cd111 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts @@ -1,7 +1,13 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; +import { + type ImageDTO, + type ImageToLatentsInvocation, + isNonRefinerMainModelConfig, + type NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -25,12 +31,15 @@ import { POSITIVE_CONDITIONING, SEAMLESS, } from './constants'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Canvas tab's Image to Image graph. */ -export const buildCanvasImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => { +export const buildCanvasImageToImageGraph = async ( + state: RootState, + initialImage: ImageDTO +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -306,6 +315,8 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -316,7 +327,7 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -335,17 +346,17 @@ export const buildCanvasImageToImageGraph = (state: RootState, initialImage: Ima } // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS); // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts index 2672cf5be3..372f7d8fe5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts @@ -40,11 +40,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; /** * Builds the Canvas tab's Inpaint graph. */ -export const buildCanvasInpaintGraph = ( +export const buildCanvasInpaintGraph = async ( state: RootState, canvasInitImage: ImageDTO, canvasMaskImage: ImageDTO -): NonNullableGraph => { +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -414,17 +414,17 @@ export const buildCanvasInpaintGraph = ( } // Add VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { // must add before watermarker! diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts index a9707e50f8..d847ccbfb5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts @@ -44,11 +44,11 @@ import { getBoardField, getIsIntermediate } from './graphBuilderUtils'; /** * Builds the Canvas tab's Outpaint graph. */ -export const buildCanvasOutpaintGraph = ( +export const buildCanvasOutpaintGraph = async ( state: RootState, canvasInitImage: ImageDTO, canvasMaskImage?: ImageDTO -): NonNullableGraph => { +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -545,18 +545,18 @@ export const buildCanvasOutpaintGraph = ( } // Add VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts index 58269afce3..059003c34b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts @@ -1,6 +1,12 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; -import type { ImageDTO, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + type ImageDTO, + type ImageToLatentsInvocation, + isNonRefinerMainModelConfig, + type NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -26,12 +32,15 @@ import { SEAMLESS, } from './constants'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Canvas tab's Image to Image graph. */ -export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: ImageDTO): NonNullableGraph => { +export const buildCanvasSDXLImageToImageGraph = async ( + state: RootState, + initialImage: ImageDTO +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -307,6 +316,8 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -317,7 +328,7 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -338,24 +349,24 @@ export const buildCanvasSDXLImageToImageGraph = (state: RootState, initialImage: // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts index 9f4e75de48..4d5ce616fa 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts @@ -41,11 +41,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu /** * Builds the Canvas tab's Inpaint graph. */ -export const buildCanvasSDXLInpaintGraph = ( +export const buildCanvasSDXLInpaintGraph = async ( state: RootState, canvasInitImage: ImageDTO, canvasMaskImage: ImageDTO -): NonNullableGraph => { +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -426,24 +426,31 @@ export const buildCanvasSDXLInpaintGraph = ( // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage, canvasMaskImage); + await addSDXLRefinerToGraph( + state, + graph, + SDXL_DENOISE_LATENTS, + modelLoaderNodeId, + canvasInitImage, + canvasMaskImage + ); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // Add VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts index 6c5a31926a..39a54fd9d1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts @@ -45,11 +45,11 @@ import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBu /** * Builds the Canvas tab's Outpaint graph. */ -export const buildCanvasSDXLOutpaintGraph = ( +export const buildCanvasSDXLOutpaintGraph = async ( state: RootState, canvasInitImage: ImageDTO, canvasMaskImage?: ImageDTO -): NonNullableGraph => { +): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -555,25 +555,25 @@ export const buildCanvasSDXLOutpaintGraph = ( // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage); + await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId, canvasInitImage); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // Add VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts index 22da39c67d..b7e1ae80b0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; -import type { NonNullableGraph } from 'services/api/types'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -24,12 +25,12 @@ import { SEAMLESS, } from './constants'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Canvas tab's Text to Image graph. */ -export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { +export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -272,6 +273,8 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -284,7 +287,7 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr negative_prompt: negativePrompt, positive_style_prompt: positiveStylePrompt, negative_style_prompt: negativeStylePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -301,24 +304,24 @@ export const buildCanvasSDXLTextToImageGraph = (state: RootState): NonNullableGr // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // add LoRA support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts index 93f0470c7a..c14da86e3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts @@ -1,7 +1,8 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { NonNullableGraph } from 'services/api/types'; +import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -23,12 +24,12 @@ import { POSITIVE_CONDITIONING, SEAMLESS, } from './constants'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Canvas tab's Text to Image graph. */ -export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph => { +export const buildCanvasTextToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -262,6 +263,8 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -272,7 +275,7 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph height: !isUsingScaledDimensions ? height : scaledBoundingBoxDimensions.height, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -289,17 +292,17 @@ export const buildCanvasTextToImageGraph = (state: RootState): NonNullableGraph } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts index d1f1546b23..120afb98ee 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts @@ -1,7 +1,13 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; +import { + type ImageResizeInvocation, + type ImageToLatentsInvocation, + isNonRefinerMainModelConfig, + type NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -24,12 +30,12 @@ import { RESIZE, SEAMLESS, } from './constants'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Image to Image tab graph. */ -export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph => { +export const buildLinearImageToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -307,6 +313,8 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -317,7 +325,7 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph width, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -336,17 +344,17 @@ export const buildLinearImageToImageGraph = (state: RootState): NonNullableGraph } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts index de4ad7cece..bd6902725b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts @@ -1,6 +1,12 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; -import type { ImageResizeInvocation, ImageToLatentsInvocation, NonNullableGraph } from 'services/api/types'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { + type ImageResizeInvocation, + type ImageToLatentsInvocation, + isNonRefinerMainModelConfig, + type NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -25,12 +31,12 @@ import { SEAMLESS, } from './constants'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; /** * Builds the Image to Image tab graph. */ -export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableGraph => { +export const buildLinearSDXLImageToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -318,6 +324,8 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG }); } + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -328,7 +336,7 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG width, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -349,25 +357,25 @@ export const buildLinearSDXLImageToImageGraph = (state: RootState): NonNullableG // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); + await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // Add LoRA Support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // Add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts index 58b97b07c7..7dadb8c3e9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts @@ -1,6 +1,7 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; -import type { NonNullableGraph } from 'services/api/types'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; @@ -23,9 +24,9 @@ import { SEAMLESS, } from './constants'; import { getBoardField, getIsIntermediate, getSDXLStylePrompts } from './graphBuilderUtils'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; -export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGraph => { +export const buildLinearSDXLTextToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -221,6 +222,8 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr ], }; + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -231,7 +234,7 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr width, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -250,25 +253,25 @@ export const buildLinearSDXLTextToImageGraph = (state: RootState): NonNullableGr // Add Refiner if enabled if (refinerModel) { - addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); + await addSDXLRefinerToGraph(state, graph, SDXL_DENOISE_LATENTS); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; } } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); + await addSDXLLoRAsToGraph(state, graph, SDXL_DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // add IP Adapter - addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS); // NSFW & watermark - must be last thing added to graph if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts index b2b84cfdad..aac1270e0d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts @@ -1,7 +1,8 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; import { getBoardField, getIsIntermediate } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { NonNullableGraph } from 'services/api/types'; +import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; @@ -23,9 +24,9 @@ import { SEAMLESS, TEXT_TO_IMAGE_GRAPH, } from './constants'; -import { addCoreMetadataNode } from './metadata'; +import { addCoreMetadataNode, getModelMetadataField } from './metadata'; -export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph => { +export const buildLinearTextToImageGraph = async (state: RootState): Promise => { const log = logger('nodes'); const { positivePrompt, @@ -212,6 +213,8 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph ], }; + const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); + addCoreMetadataNode( graph, { @@ -222,7 +225,7 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph width, positive_prompt: positivePrompt, negative_prompt: negativePrompt, - model, + model: getModelMetadataField(modelConfig), seed, steps, rand_device: use_cpu ? 'cpu' : 'cuda', @@ -239,18 +242,18 @@ export const buildLinearTextToImageGraph = (state: RootState): NonNullableGraph } // optionally add custom VAE - addVAEToGraph(state, graph, modelLoaderNodeId); + await addVAEToGraph(state, graph, modelLoaderNodeId); // add LoRA support - addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); + await addLoRAsToGraph(state, graph, DENOISE_LATENTS, modelLoaderNodeId); // add controlnet, mutating `graph` - addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); + await addControlNetToLinearGraph(state, graph, DENOISE_LATENTS); // add IP Adapter - addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); + await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS); - addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); + await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS); // High resolution fix. if (state.hrf.hrfEnabled) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts b/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts index c48f54d191..3a87b30fd7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts @@ -1,5 +1,5 @@ import type { JSONObject } from 'common/types'; -import type { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types'; +import type { AnyModelConfig, CoreMetadataInvocation, ModelMetadataField, NonNullableGraph } from 'services/api/types'; import { METADATA } from './constants'; @@ -71,3 +71,11 @@ export const setMetadataReceivingNode = (graph: NonNullableGraph, nodeId: string }, }); }; + +export const getModelMetadataField = ({ key, hash, name, base, type }: AnyModelConfig): ModelMetadataField => ({ + key, + hash, + name, + base, + type, +}); diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 3adc140383..b12f2aebe8 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -38,7 +38,6 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT export type ModelType = S['ModelType']; export type SubModelType = S['SubModelType']; export type BaseModelType = S['BaseModelType']; -export type ControlField = S['ControlField']; // Model Configs @@ -129,8 +128,9 @@ export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation']; export type ImageWatermarkInvocation = S['ImageWatermarkInvocation']; export type SeamlessModeInvocation = S['SeamlessModeInvocation']; export type CoreMetadataInvocation = S['CoreMetadataInvocation']; -export type IPAdapterMetadataField = S['IPAdapterMetadataField']; -export type T2IAdapterField = S['T2IAdapterField']; + +// Metadata fields +export type ModelMetadataField = S['ModelMetadataField']; // ControlNet Nodes export type ControlNetInvocation = S['ControlNetInvocation'];