mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): handle missing images/models when recalling control layers
This commit is contained in:
parent
00f36cb491
commit
23ad6fb730
@ -1,4 +1,5 @@
|
|||||||
import { getStore } from 'app/store/nanostores/store';
|
import { getStore } from 'app/store/nanostores/store';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import {
|
import {
|
||||||
controlAdapterRecalled,
|
controlAdapterRecalled,
|
||||||
controlNetsReset,
|
controlNetsReset,
|
||||||
@ -28,6 +29,7 @@ import type {
|
|||||||
MetadataRecallFunc,
|
MetadataRecallFunc,
|
||||||
T2IAdapterConfigMetadata,
|
T2IAdapterConfigMetadata,
|
||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
|
import { fetchModelConfigByIdentifier } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
setCfgRescaleMultiplier,
|
setCfgRescaleMultiplier,
|
||||||
@ -69,6 +71,7 @@ import {
|
|||||||
setRefinerStart,
|
setRefinerStart,
|
||||||
setRefinerSteps,
|
setRefinerSteps,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
|
import { getImageDTO } from 'services/api/endpoints/images';
|
||||||
|
|
||||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||||
getStore().dispatch(positivePromptChanged(positivePrompt));
|
getStore().dispatch(positivePromptChanged(positivePrompt));
|
||||||
@ -237,21 +240,79 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
|
|||||||
};
|
};
|
||||||
|
|
||||||
//#region Control Layers
|
//#region Control Layers
|
||||||
const recallLayer: MetadataRecallFunc<Layer> = (layer) => {
|
const recallLayer: MetadataRecallFunc<Layer> = async (layer) => {
|
||||||
const { dispatch } = getStore();
|
const { dispatch } = getStore();
|
||||||
switch (layer.type) {
|
// We need to check for the existence of all images and models when recalling. If they do not exist, SMITE THEM!
|
||||||
case 'control_adapter_layer':
|
if (layer.type === 'control_adapter_layer') {
|
||||||
dispatch(caLayerRecalled(layer));
|
const clone = deepClone(layer);
|
||||||
break;
|
if (clone.controlAdapter.image) {
|
||||||
case 'ip_adapter_layer':
|
const imageDTO = await getImageDTO(clone.controlAdapter.image.name);
|
||||||
dispatch(ipaLayerRecalled(layer));
|
if (!imageDTO) {
|
||||||
break;
|
clone.controlAdapter.image = null;
|
||||||
case 'regional_guidance_layer':
|
}
|
||||||
dispatch(rgLayerRecalled(layer));
|
}
|
||||||
break;
|
if (clone.controlAdapter.processedImage) {
|
||||||
case 'initial_image_layer':
|
const imageDTO = await getImageDTO(clone.controlAdapter.processedImage.name);
|
||||||
dispatch(iiLayerRecalled(layer));
|
if (!imageDTO) {
|
||||||
break;
|
clone.controlAdapter.processedImage = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (clone.controlAdapter.model) {
|
||||||
|
try {
|
||||||
|
await fetchModelConfigByIdentifier(clone.controlAdapter.model);
|
||||||
|
} catch {
|
||||||
|
clone.controlAdapter.model = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dispatch(caLayerRecalled(clone));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (layer.type === 'ip_adapter_layer') {
|
||||||
|
const clone = deepClone(layer);
|
||||||
|
if (clone.ipAdapter.image) {
|
||||||
|
const imageDTO = await getImageDTO(clone.ipAdapter.image.name);
|
||||||
|
if (!imageDTO) {
|
||||||
|
clone.ipAdapter.image = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (clone.ipAdapter.model) {
|
||||||
|
try {
|
||||||
|
await fetchModelConfigByIdentifier(clone.ipAdapter.model);
|
||||||
|
} catch {
|
||||||
|
clone.ipAdapter.model = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dispatch(ipaLayerRecalled(clone));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layer.type === 'regional_guidance_layer') {
|
||||||
|
const clone = deepClone(layer);
|
||||||
|
// Strip out the uploaded mask image property - this is an intermediate image
|
||||||
|
clone.uploadedMaskImage = null;
|
||||||
|
|
||||||
|
for (const ipAdapter of clone.ipAdapters) {
|
||||||
|
if (ipAdapter.image) {
|
||||||
|
const imageDTO = await getImageDTO(ipAdapter.image.name);
|
||||||
|
if (!imageDTO) {
|
||||||
|
ipAdapter.image = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (ipAdapter.model) {
|
||||||
|
try {
|
||||||
|
await fetchModelConfigByIdentifier(ipAdapter.model);
|
||||||
|
} catch {
|
||||||
|
ipAdapter.model = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dispatch(rgLayerRecalled(clone));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (layer.type === 'initial_image_layer') {
|
||||||
|
dispatch(iiLayerRecalled(layer));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ import type {
|
|||||||
MetadataValidateFunc,
|
MetadataValidateFunc,
|
||||||
T2IAdapterConfigMetadata,
|
T2IAdapterConfigMetadata,
|
||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import type { BaseModelType } from 'services/api/types';
|
import type { BaseModelType } from 'services/api/types';
|
||||||
import { assert } from 'tsafe';
|
import { assert } from 'tsafe';
|
||||||
@ -115,32 +115,29 @@ const validateLayer: MetadataValidateFunc<Layer> = async (layer) => {
|
|||||||
const model = layer.controlAdapter.model;
|
const model = layer.controlAdapter.model;
|
||||||
assert(model, 'Control Adapter layer missing model');
|
assert(model, 'Control Adapter layer missing model');
|
||||||
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
||||||
fetchModelConfigByIdentifier(model);
|
|
||||||
}
|
}
|
||||||
if (layer.type === 'ip_adapter_layer') {
|
if (layer.type === 'ip_adapter_layer') {
|
||||||
const model = layer.ipAdapter.model;
|
const model = layer.ipAdapter.model;
|
||||||
assert(model, 'IP Adapter layer missing model');
|
assert(model, 'IP Adapter layer missing model');
|
||||||
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
||||||
fetchModelConfigByIdentifier(model);
|
|
||||||
}
|
}
|
||||||
if (layer.type === 'regional_guidance_layer') {
|
if (layer.type === 'regional_guidance_layer') {
|
||||||
for (const ipa of layer.ipAdapters) {
|
for (const ipa of layer.ipAdapters) {
|
||||||
const model = ipa.model;
|
const model = ipa.model;
|
||||||
assert(model, 'IP Adapter layer missing model');
|
assert(model, 'IP Adapter layer missing model');
|
||||||
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
|
||||||
fetchModelConfigByIdentifier(model);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return layer;
|
return layer;
|
||||||
};
|
};
|
||||||
|
|
||||||
const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => {
|
const validateLayers: MetadataValidateFunc<Layer[]> = async (layers) => {
|
||||||
const validatedLayers: Layer[] = [];
|
const validatedLayers: Layer[] = [];
|
||||||
for (const l of layers) {
|
for (const l of layers) {
|
||||||
try {
|
try {
|
||||||
validateLayer(l);
|
const validated = await validateLayer(l);
|
||||||
validatedLayers.push(l);
|
validatedLayers.push(validated);
|
||||||
} catch {
|
} catch {
|
||||||
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
|
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user