From 23ad6fb730013c40c83c87393ac191f74ff363c2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 8 May 2024 15:49:19 +1000 Subject: [PATCH] feat(ui): handle missing images/models when recalling control layers --- .../src/features/metadata/util/recallers.ts | 89 ++++++++++++++++--- .../src/features/metadata/util/validators.ts | 11 +-- 2 files changed, 79 insertions(+), 21 deletions(-) diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts index 09c405c3d6..673e2187f2 100644 --- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -1,4 +1,5 @@ import { getStore } from 'app/store/nanostores/store'; +import { deepClone } from 'common/util/deepClone'; import { controlAdapterRecalled, controlNetsReset, @@ -28,6 +29,7 @@ import type { MetadataRecallFunc, T2IAdapterConfigMetadata, } from 'features/metadata/types'; +import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers'; import { modelSelected } from 'features/parameters/store/actions'; import { setCfgRescaleMultiplier, @@ -69,6 +71,7 @@ import { setRefinerStart, setRefinerSteps, } from 'features/sdxl/store/sdxlSlice'; +import { getImageDTO } from 'services/api/endpoints/images'; const recallPositivePrompt: MetadataRecallFunc = (positivePrompt) => { getStore().dispatch(positivePromptChanged(positivePrompt)); @@ -237,21 +240,79 @@ const recallIPAdapters: MetadataRecallFunc = (ipAdapt }; //#region Control Layers -const recallLayer: MetadataRecallFunc = (layer) => { +const recallLayer: MetadataRecallFunc = async (layer) => { const { dispatch } = getStore(); - switch (layer.type) { - case 'control_adapter_layer': - dispatch(caLayerRecalled(layer)); - break; - case 'ip_adapter_layer': - dispatch(ipaLayerRecalled(layer)); - break; - case 'regional_guidance_layer': - dispatch(rgLayerRecalled(layer)); - break; - case 'initial_image_layer': - dispatch(iiLayerRecalled(layer)); - break; + // We need to check for the existence of all images and models when recalling. If they do not exist, SMITE THEM! + if (layer.type === 'control_adapter_layer') { + const clone = deepClone(layer); + if (clone.controlAdapter.image) { + const imageDTO = await getImageDTO(clone.controlAdapter.image.name); + if (!imageDTO) { + clone.controlAdapter.image = null; + } + } + if (clone.controlAdapter.processedImage) { + const imageDTO = await getImageDTO(clone.controlAdapter.processedImage.name); + if (!imageDTO) { + clone.controlAdapter.processedImage = null; + } + } + if (clone.controlAdapter.model) { + try { + await fetchModelConfigByIdentifier(clone.controlAdapter.model); + } catch { + clone.controlAdapter.model = null; + } + } + dispatch(caLayerRecalled(clone)); + return; + } + if (layer.type === 'ip_adapter_layer') { + const clone = deepClone(layer); + if (clone.ipAdapter.image) { + const imageDTO = await getImageDTO(clone.ipAdapter.image.name); + if (!imageDTO) { + clone.ipAdapter.image = null; + } + } + if (clone.ipAdapter.model) { + try { + await fetchModelConfigByIdentifier(clone.ipAdapter.model); + } catch { + clone.ipAdapter.model = null; + } + } + dispatch(ipaLayerRecalled(clone)); + return; + } + + if (layer.type === 'regional_guidance_layer') { + const clone = deepClone(layer); + // Strip out the uploaded mask image property - this is an intermediate image + clone.uploadedMaskImage = null; + + for (const ipAdapter of clone.ipAdapters) { + if (ipAdapter.image) { + const imageDTO = await getImageDTO(ipAdapter.image.name); + if (!imageDTO) { + ipAdapter.image = null; + } + } + if (ipAdapter.model) { + try { + await fetchModelConfigByIdentifier(ipAdapter.model); + } catch { + ipAdapter.model = null; + } + } + } + dispatch(rgLayerRecalled(clone)); + return; + } + + if (layer.type === 'initial_image_layer') { + dispatch(iiLayerRecalled(layer)); + return; } }; diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts index a308021a1e..759e8ba561 100644 --- a/invokeai/frontend/web/src/features/metadata/util/validators.ts +++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts @@ -7,7 +7,7 @@ import type { MetadataValidateFunc, T2IAdapterConfigMetadata, } from 'features/metadata/types'; -import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; +import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import type { BaseModelType } from 'services/api/types'; import { assert } from 'tsafe'; @@ -115,32 +115,29 @@ const validateLayer: MetadataValidateFunc = async (layer) => { const model = layer.controlAdapter.model; assert(model, 'Control Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); - fetchModelConfigByIdentifier(model); } if (layer.type === 'ip_adapter_layer') { const model = layer.ipAdapter.model; assert(model, 'IP Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); - fetchModelConfigByIdentifier(model); } if (layer.type === 'regional_guidance_layer') { for (const ipa of layer.ipAdapters) { const model = ipa.model; assert(model, 'IP Adapter layer missing model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); - fetchModelConfigByIdentifier(model); } } return layer; }; -const validateLayers: MetadataValidateFunc = (layers) => { +const validateLayers: MetadataValidateFunc = async (layers) => { const validatedLayers: Layer[] = []; for (const l of layers) { try { - validateLayer(l); - validatedLayers.push(l); + const validated = await validateLayer(l); + validatedLayers.push(validated); } catch { // This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid. }