feat(ui): handle missing images/models when recalling control layers

This commit is contained in:
psychedelicious 2024-05-08 15:49:19 +10:00 committed by Kent Keirsey
parent 00f36cb491
commit 23ad6fb730
2 changed files with 79 additions and 21 deletions

View File

@ -1,4 +1,5 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { import {
controlAdapterRecalled, controlAdapterRecalled,
controlNetsReset, controlNetsReset,
@ -28,6 +29,7 @@ import type {
MetadataRecallFunc, MetadataRecallFunc,
T2IAdapterConfigMetadata, T2IAdapterConfigMetadata,
} from 'features/metadata/types'; } from 'features/metadata/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import { modelSelected } from 'features/parameters/store/actions'; import { modelSelected } from 'features/parameters/store/actions';
import { import {
setCfgRescaleMultiplier, setCfgRescaleMultiplier,
@ -69,6 +71,7 @@ import {
setRefinerStart, setRefinerStart,
setRefinerSteps, setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice'; } from 'features/sdxl/store/sdxlSlice';
import { getImageDTO } from 'services/api/endpoints/images';
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => { const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
getStore().dispatch(positivePromptChanged(positivePrompt)); getStore().dispatch(positivePromptChanged(positivePrompt));
@ -237,21 +240,79 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
}; };
//#region Control Layers //#region Control Layers
const recallLayer: MetadataRecallFunc<Layer> = (layer) => { const recallLayer: MetadataRecallFunc<Layer> = async (layer) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
switch (layer.type) { // We need to check for the existence of all images and models when recalling. If they do not exist, SMITE THEM!
case 'control_adapter_layer': if (layer.type === 'control_adapter_layer') {
dispatch(caLayerRecalled(layer)); const clone = deepClone(layer);
break; if (clone.controlAdapter.image) {
case 'ip_adapter_layer': const imageDTO = await getImageDTO(clone.controlAdapter.image.name);
dispatch(ipaLayerRecalled(layer)); if (!imageDTO) {
break; clone.controlAdapter.image = null;
case 'regional_guidance_layer': }
dispatch(rgLayerRecalled(layer)); }
break; if (clone.controlAdapter.processedImage) {
case 'initial_image_layer': const imageDTO = await getImageDTO(clone.controlAdapter.processedImage.name);
dispatch(iiLayerRecalled(layer)); if (!imageDTO) {
break; clone.controlAdapter.processedImage = null;
}
}
if (clone.controlAdapter.model) {
try {
await fetchModelConfigByIdentifier(clone.controlAdapter.model);
} catch {
clone.controlAdapter.model = null;
}
}
dispatch(caLayerRecalled(clone));
return;
}
if (layer.type === 'ip_adapter_layer') {
const clone = deepClone(layer);
if (clone.ipAdapter.image) {
const imageDTO = await getImageDTO(clone.ipAdapter.image.name);
if (!imageDTO) {
clone.ipAdapter.image = null;
}
}
if (clone.ipAdapter.model) {
try {
await fetchModelConfigByIdentifier(clone.ipAdapter.model);
} catch {
clone.ipAdapter.model = null;
}
}
dispatch(ipaLayerRecalled(clone));
return;
}
if (layer.type === 'regional_guidance_layer') {
const clone = deepClone(layer);
// Strip out the uploaded mask image property - this is an intermediate image
clone.uploadedMaskImage = null;
for (const ipAdapter of clone.ipAdapters) {
if (ipAdapter.image) {
const imageDTO = await getImageDTO(ipAdapter.image.name);
if (!imageDTO) {
ipAdapter.image = null;
}
}
if (ipAdapter.model) {
try {
await fetchModelConfigByIdentifier(ipAdapter.model);
} catch {
ipAdapter.model = null;
}
}
}
dispatch(rgLayerRecalled(clone));
return;
}
if (layer.type === 'initial_image_layer') {
dispatch(iiLayerRecalled(layer));
return;
} }
}; };

View File

@ -7,7 +7,7 @@ import type {
MetadataValidateFunc, MetadataValidateFunc,
T2IAdapterConfigMetadata, T2IAdapterConfigMetadata,
} from 'features/metadata/types'; } from 'features/metadata/types';
import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types'; import type { BaseModelType } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -115,32 +115,29 @@ const validateLayer: MetadataValidateFunc<Layer> = async (layer) => {
const model = layer.controlAdapter.model; const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model'); assert(model, 'Control Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
} }
if (layer.type === 'ip_adapter_layer') { if (layer.type === 'ip_adapter_layer') {
const model = layer.ipAdapter.model; const model = layer.ipAdapter.model;
assert(model, 'IP Adapter layer missing model'); assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
} }
if (layer.type === 'regional_guidance_layer') { if (layer.type === 'regional_guidance_layer') {
for (const ipa of layer.ipAdapters) { for (const ipa of layer.ipAdapters) {
const model = ipa.model; const model = ipa.model;
assert(model, 'IP Adapter layer missing model'); assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
} }
} }
return layer; return layer;
}; };
const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => { const validateLayers: MetadataValidateFunc<Layer[]> = async (layers) => {
const validatedLayers: Layer[] = []; const validatedLayers: Layer[] = [];
for (const l of layers) { for (const l of layers) {
try { try {
validateLayer(l); const validated = await validateLayer(l);
validatedLayers.push(l); validatedLayers.push(validated);
} catch { } catch {
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid. // This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
} }