feat(ui): handle missing images/models when recalling control layers

This commit is contained in:
psychedelicious 2024-05-08 15:49:19 +10:00 committed by Kent Keirsey
parent 00f36cb491
commit 23ad6fb730
2 changed files with 79 additions and 21 deletions

View File

@ -1,4 +1,5 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import {
controlAdapterRecalled,
controlNetsReset,
@ -28,6 +29,7 @@ import type {
MetadataRecallFunc,
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
import { modelSelected } from 'features/parameters/store/actions';
import {
setCfgRescaleMultiplier,
@ -69,6 +71,7 @@ import {
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import { getImageDTO } from 'services/api/endpoints/images';
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
getStore().dispatch(positivePromptChanged(positivePrompt));
@ -237,21 +240,79 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
};
//#region Control Layers
const recallLayer: MetadataRecallFunc<Layer> = (layer) => {
const recallLayer: MetadataRecallFunc<Layer> = async (layer) => {
const { dispatch } = getStore();
switch (layer.type) {
case 'control_adapter_layer':
dispatch(caLayerRecalled(layer));
break;
case 'ip_adapter_layer':
dispatch(ipaLayerRecalled(layer));
break;
case 'regional_guidance_layer':
dispatch(rgLayerRecalled(layer));
break;
case 'initial_image_layer':
// We need to check for the existence of all images and models when recalling. If they do not exist, SMITE THEM!
if (layer.type === 'control_adapter_layer') {
const clone = deepClone(layer);
if (clone.controlAdapter.image) {
const imageDTO = await getImageDTO(clone.controlAdapter.image.name);
if (!imageDTO) {
clone.controlAdapter.image = null;
}
}
if (clone.controlAdapter.processedImage) {
const imageDTO = await getImageDTO(clone.controlAdapter.processedImage.name);
if (!imageDTO) {
clone.controlAdapter.processedImage = null;
}
}
if (clone.controlAdapter.model) {
try {
await fetchModelConfigByIdentifier(clone.controlAdapter.model);
} catch {
clone.controlAdapter.model = null;
}
}
dispatch(caLayerRecalled(clone));
return;
}
if (layer.type === 'ip_adapter_layer') {
const clone = deepClone(layer);
if (clone.ipAdapter.image) {
const imageDTO = await getImageDTO(clone.ipAdapter.image.name);
if (!imageDTO) {
clone.ipAdapter.image = null;
}
}
if (clone.ipAdapter.model) {
try {
await fetchModelConfigByIdentifier(clone.ipAdapter.model);
} catch {
clone.ipAdapter.model = null;
}
}
dispatch(ipaLayerRecalled(clone));
return;
}
if (layer.type === 'regional_guidance_layer') {
const clone = deepClone(layer);
// Strip out the uploaded mask image property - this is an intermediate image
clone.uploadedMaskImage = null;
for (const ipAdapter of clone.ipAdapters) {
if (ipAdapter.image) {
const imageDTO = await getImageDTO(ipAdapter.image.name);
if (!imageDTO) {
ipAdapter.image = null;
}
}
if (ipAdapter.model) {
try {
await fetchModelConfigByIdentifier(ipAdapter.model);
} catch {
ipAdapter.model = null;
}
}
}
dispatch(rgLayerRecalled(clone));
return;
}
if (layer.type === 'initial_image_layer') {
dispatch(iiLayerRecalled(layer));
break;
return;
}
};

View File

@ -7,7 +7,7 @@ import type {
MetadataValidateFunc,
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types';
import { assert } from 'tsafe';
@ -115,32 +115,29 @@ const validateLayer: MetadataValidateFunc<Layer> = async (layer) => {
const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
if (layer.type === 'ip_adapter_layer') {
const model = layer.ipAdapter.model;
assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
if (layer.type === 'regional_guidance_layer') {
for (const ipa of layer.ipAdapters) {
const model = ipa.model;
assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
}
return layer;
};
const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => {
const validateLayers: MetadataValidateFunc<Layer[]> = async (layers) => {
const validatedLayers: Layer[] = [];
for (const l of layers) {
try {
validateLayer(l);
validatedLayers.push(l);
const validated = await validateLayer(l);
validatedLayers.push(validated);
} catch {
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
}