feat(ui): individual layer recall

This commit is contained in:
psychedelicious 2024-05-08 09:58:34 +10:00 committed by Kent Keirsey
parent b43b2714cc
commit dfbd7eb1cf
8 changed files with 192 additions and 39 deletions

View File

@ -124,6 +124,12 @@ const getVectorMaskPreviewColor = (state: ControlLayersState): RgbColor => {
const lastColor = rgLayers[rgLayers.length - 1]?.previewColor; const lastColor = rgLayers[rgLayers.length - 1]?.previewColor;
return LayerColors.next(lastColor); return LayerColors.next(lastColor);
}; };
const deselectAllLayers = (state: ControlLayersState) => {
for (const layer of state.layers.filter(isRenderableLayer)) {
layer.isSelected = false;
}
state.selectedLayerId = null;
};
export const controlLayersSlice = createSlice({ export const controlLayersSlice = createSlice({
name: 'controlLayers', name: 'controlLayers',
@ -256,6 +262,7 @@ export const controlLayersSlice = createSlice({
}), }),
}, },
caLayerRecalled: (state, action: PayloadAction<ControlAdapterLayer>) => { caLayerRecalled: (state, action: PayloadAction<ControlAdapterLayer>) => {
deselectAllLayers(state);
state.layers.push({ ...action.payload, isSelected: true }); state.layers.push({ ...action.payload, isSelected: true });
state.selectedLayerId = action.payload.id; state.selectedLayerId = action.payload.id;
}, },
@ -470,6 +477,7 @@ export const controlLayersSlice = createSlice({
prepare: () => ({ payload: { layerId: uuidv4() } }), prepare: () => ({ payload: { layerId: uuidv4() } }),
}, },
rgLayerRecalled: (state, action: PayloadAction<RegionalGuidanceLayer>) => { rgLayerRecalled: (state, action: PayloadAction<RegionalGuidanceLayer>) => {
deselectAllLayers(state);
state.layers.push({ ...action.payload, isSelected: true }); state.layers.push({ ...action.payload, isSelected: true });
state.selectedLayerId = action.payload.id; state.selectedLayerId = action.payload.id;
}, },
@ -665,6 +673,12 @@ export const controlLayersSlice = createSlice({
}, },
prepare: (imageDTO: ImageDTO | null) => ({ payload: { layerId: 'initial_image_layer', imageDTO } }), prepare: (imageDTO: ImageDTO | null) => ({ payload: { layerId: 'initial_image_layer', imageDTO } }),
}, },
iiLayerRecalled: (state, action: PayloadAction<InitialImageLayer>) => {
deselectAllLayers(state);
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
state.layers.push({ ...action.payload, isSelected: true });
state.selectedLayerId = action.payload.id;
},
iiLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { iiLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload; const { layerId, imageDTO } = action.payload;
const layer = selectIILayerOrThrow(state, layerId); const layer = selectIILayerOrThrow(state, layerId);
@ -859,6 +873,7 @@ export const {
rgLayerIPAdapterCLIPVisionModelChanged, rgLayerIPAdapterCLIPVisionModelChanged,
// II Layer // II Layer
iiLayerAdded, iiLayerAdded,
iiLayerRecalled,
iiLayerImageChanged, iiLayerImageChanged,
iiLayerOpacityChanged, iiLayerOpacityChanged,
iiLayerDenoisingStrengthChanged, iiLayerDenoisingStrengthChanged,

View File

@ -4,6 +4,7 @@ import { MetadataControlNetsV2 } from 'features/metadata/components/MetadataCont
import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters'; import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
import { MetadataIPAdaptersV2 } from 'features/metadata/components/MetadataIPAdaptersV2'; import { MetadataIPAdaptersV2 } from 'features/metadata/components/MetadataIPAdaptersV2';
import { MetadataItem } from 'features/metadata/components/MetadataItem'; import { MetadataItem } from 'features/metadata/components/MetadataItem';
import { MetadataLayers } from 'features/metadata/components/MetadataLayers';
import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs'; import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters'; import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
import { MetadataT2IAdaptersV2 } from 'features/metadata/components/MetadataT2IAdaptersV2'; import { MetadataT2IAdaptersV2 } from 'features/metadata/components/MetadataT2IAdaptersV2';
@ -51,8 +52,8 @@ const ImageMetadataActions = (props: Props) => {
<MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} /> <MetadataItem metadata={metadata} handlers={handlers.refinerScheduler} />
<MetadataItem metadata={metadata} handlers={handlers.refinerStart} /> <MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
<MetadataItem metadata={metadata} handlers={handlers.refinerSteps} /> <MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
<MetadataItem metadata={metadata} handlers={handlers.layers} />
<MetadataLoRAs metadata={metadata} /> <MetadataLoRAs metadata={metadata} />
<MetadataLayers metadata={metadata} />
{activeTabName !== 'generation' && <MetadataControlNets metadata={metadata} />} {activeTabName !== 'generation' && <MetadataControlNets metadata={metadata} />}
{activeTabName !== 'generation' && <MetadataT2IAdapters metadata={metadata} />} {activeTabName !== 'generation' && <MetadataT2IAdapters metadata={metadata} />}
{activeTabName !== 'generation' && <MetadataIPAdapters metadata={metadata} />} {activeTabName !== 'generation' && <MetadataIPAdapters metadata={metadata} />}

View File

@ -0,0 +1,68 @@
import type { Layer } from 'features/controlLayers/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataLayers = ({ metadata }: Props) => {
const [layers, setLayers] = useState<Layer[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.layers.parse(metadata);
setLayers(parsed);
} catch (e) {
setLayers([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.layers.getLabel(), []);
return (
<>
{layers.map((layer) => (
<MetadataViewLayer key={layer.id} label={label} layer={layer} handlers={handlers.layers} />
))}
</>
);
};
const MetadataViewLayer = ({
label,
layer,
handlers,
}: {
label: string;
layer: Layer;
handlers: MetadataHandlers<Layer[], Layer>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(layer, true);
}, [handlers, layer]);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(layer);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, layer]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -17,6 +17,7 @@ import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
import { validators } from 'features/metadata/util/validators'; import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { t } from 'i18next'; import { t } from 'i18next';
import { assert } from 'tsafe';
import { parsers } from './parsers'; import { parsers } from './parsers';
import { recallers } from './recallers'; import { recallers } from './recallers';
@ -53,8 +54,23 @@ const renderControlAdapterValueV2: MetadataRenderValueFunc<AnyControlAdapterConf
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`; return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
} }
}; };
const renderLayersValue: MetadataRenderValueFunc<Layer[]> = async (value) => { const renderLayerValue: MetadataRenderValueFunc<Layer> = async (layer) => {
return `${value.length} ${t('controlLayers.layers', { count: value.length })}`; if (layer.type === 'initial_image_layer') {
return t('controlLayers.initialImageLayer');
}
if (layer.type === 'control_adapter_layer') {
return t('controlLayers.controlAdapterLayer');
}
if (layer.type === 'ip_adapter_layer') {
return t('controlLayers.ipAdapterLayer');
}
if (layer.type === 'regional_guidance_layer') {
return t('controlLayers.regionalGuidanceLayer');
}
assert(false, 'Unknown layer type');
};
const renderLayersValue: MetadataRenderValueFunc<Layer[]> = async (layers) => {
return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`;
}; };
const parameterSetToast = (parameter: string, description?: string) => { const parameterSetToast = (parameter: string, description?: string) => {
@ -389,8 +405,12 @@ export const handlers = {
layers: buildHandlers({ layers: buildHandlers({
getLabel: () => t('controlLayers.layers_other'), getLabel: () => t('controlLayers.layers_other'),
parser: parsers.layers, parser: parsers.layers,
itemParser: parsers.layer,
recaller: recallers.layers, recaller: recallers.layers,
itemRecaller: recallers.layer,
validator: validators.layers, validator: validators.layers,
itemValidator: validators.layer,
renderItemValue: renderLayerValue,
renderValue: renderLayersValue, renderValue: renderLayersValue,
getIsVisible: (value) => value.length > 0, getIsVisible: (value) => value.length > 0,
}), }),

View File

@ -1,4 +1,5 @@
import { getStore } from 'app/store/nanostores/store'; import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models'; import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@ -68,6 +69,24 @@ const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type:
} }
}; };
/**
* Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs.
* @param identifier The model identifier.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise<AnyModelConfig> => {
try {
return await fetchModelConfig(identifier.key);
} catch {
try {
return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type);
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`);
}
}
};
/** /**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type. * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key. * @param key The model key.

View File

@ -625,19 +625,6 @@ const parseIPAdapterV2: MetadataParseFunc<IPAdapterConfigV2Metadata> = async (me
return ipAdapter; return ipAdapter;
}; };
const parseLayers: MetadataParseFunc<Layer[]> = async (metadata) => {
try {
const layersRaw = await getProperty(metadata, 'layers', isArray);
const parseResults = await Promise.allSettled(layersRaw.map((layerRaw) => zLayer.parseAsync(layerRaw)));
const layers = parseResults
.filter((result): result is PromiseFulfilledResult<Layer> => result.status === 'fulfilled')
.map((result) => result.value);
return layers;
} catch {
return [];
}
};
const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => { const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => {
try { try {
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray); const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
@ -651,6 +638,21 @@ const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = asy
} }
}; };
const parseLayer: MetadataParseFunc<Layer> = async (metadataItem) => zLayer.parseAsync(metadataItem);
const parseLayers: MetadataParseFunc<Layer[]> = async (metadata) => {
try {
const layersRaw = await getProperty(metadata, 'layers', isArray);
const parseResults = await Promise.allSettled(layersRaw.map(parseLayer));
const layers = parseResults
.filter((result): result is PromiseFulfilledResult<Layer> => result.status === 'fulfilled')
.map((result) => result.value);
return layers;
} catch {
return [];
}
};
export const parsers = { export const parsers = {
createdBy: parseCreatedBy, createdBy: parseCreatedBy,
generationMode: parseGenerationMode, generationMode: parseGenerationMode,
@ -693,5 +695,6 @@ export const parsers = {
t2iAdaptersV2: parseAllT2IAdaptersV2, t2iAdaptersV2: parseAllT2IAdaptersV2,
ipAdapterV2: parseIPAdapterV2, ipAdapterV2: parseIPAdapterV2,
ipAdaptersV2: parseAllIPAdaptersV2, ipAdaptersV2: parseAllIPAdaptersV2,
layer: parseLayer,
layers: parseLayers, layers: parseLayers,
} as const; } as const;

View File

@ -13,6 +13,7 @@ import {
caLayerT2IAdaptersDeleted, caLayerT2IAdaptersDeleted,
heightChanged, heightChanged,
iiLayerAdded, iiLayerAdded,
iiLayerRecalled,
ipaLayerAdded, ipaLayerAdded,
ipaLayerRecalled, ipaLayerRecalled,
ipaLayersDeleted, ipaLayersDeleted,
@ -295,21 +296,29 @@ const recallIPAdaptersV2: MetadataRecallFunc<IPAdapterConfigV2Metadata[]> = (ipA
}); });
}; };
const recallLayer: MetadataRecallFunc<Layer> = (layer) => {
const { dispatch } = getStore();
switch (layer.type) {
case 'control_adapter_layer':
dispatch(caLayerRecalled(layer));
break;
case 'ip_adapter_layer':
dispatch(ipaLayerRecalled(layer));
break;
case 'regional_guidance_layer':
dispatch(rgLayerRecalled(layer));
break;
case 'initial_image_layer':
dispatch(iiLayerRecalled(layer));
break;
}
};
const recallLayers: MetadataRecallFunc<Layer[]> = (layers) => { const recallLayers: MetadataRecallFunc<Layer[]> = (layers) => {
const { dispatch } = getStore(); const { dispatch } = getStore();
dispatch(allLayersDeleted()); dispatch(allLayersDeleted());
for (const l of layers) { for (const l of layers) {
switch (l.type) { recallLayer(l);
case 'control_adapter_layer':
dispatch(caLayerRecalled(l));
break;
case 'ip_adapter_layer':
dispatch(ipaLayerRecalled(l));
break;
case 'regional_guidance_layer':
dispatch(rgLayerRecalled(l));
break;
}
} }
}; };
@ -353,5 +362,6 @@ export const recallers = {
t2iAdaptersV2: recallT2IAdaptersV2, t2iAdaptersV2: recallT2IAdaptersV2,
ipAdapterV2: recallIPAdapterV2, ipAdapterV2: recallIPAdapterV2,
ipAdaptersV2: recallIPAdaptersV2, ipAdaptersV2: recallIPAdaptersV2,
layer: recallLayer,
layers: recallLayers, layers: recallLayers,
} as const; } as const;

View File

@ -10,9 +10,10 @@ import type {
T2IAdapterConfigMetadata, T2IAdapterConfigMetadata,
T2IAdapterConfigV2Metadata, T2IAdapterConfigV2Metadata,
} from 'features/metadata/types'; } from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types'; import type { BaseModelType } from 'services/api/types';
import { assert } from 'tsafe';
/** /**
* Checks the given base model type against the currently-selected model's base type and throws an error if they are * Checks the given base model type against the currently-selected model's base type and throws an error if they are
@ -166,21 +167,36 @@ const validateIPAdaptersV2: MetadataValidateFunc<IPAdapterConfigV2Metadata[]> =
return new Promise((resolve) => resolve(validatedIPAdapters)); return new Promise((resolve) => resolve(validatedIPAdapters));
}; };
const validateLayer: MetadataValidateFunc<Layer> = async (layer) => {
if (layer.type === 'control_adapter_layer') {
const model = layer.controlAdapter.model;
assert(model, 'Control Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
if (layer.type === 'ip_adapter_layer') {
const model = layer.ipAdapter.model;
assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
if (layer.type === 'regional_guidance_layer') {
for (const ipa of layer.ipAdapters) {
const model = ipa.model;
assert(model, 'IP Adapter layer missing model');
validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model');
fetchModelConfigByIdentifier(model);
}
}
return layer;
};
const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => { const validateLayers: MetadataValidateFunc<Layer[]> = (layers) => {
const validatedLayers: Layer[] = []; const validatedLayers: Layer[] = [];
for (const l of layers) { for (const l of layers) {
try { try {
if (l.type === 'control_adapter_layer') { validateLayer(l);
validateBaseCompatibility(l.controlAdapter.model?.base, 'Layer incompatible with currently-selected model');
}
if (l.type === 'ip_adapter_layer') {
validateBaseCompatibility(l.ipAdapter.model?.base, 'Layer incompatible with currently-selected model');
}
if (l.type === 'regional_guidance_layer') {
for (const ipa of l.ipAdapters) {
validateBaseCompatibility(ipa.model?.base, 'Layer incompatible with currently-selected model');
}
}
validatedLayers.push(l); validatedLayers.push(l);
} catch { } catch {
// This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid. // This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid.
@ -206,5 +222,6 @@ export const validators = {
t2iAdaptersV2: validateT2IAdaptersV2, t2iAdaptersV2: validateT2IAdaptersV2,
ipAdapterV2: validateIPAdapterV2, ipAdapterV2: validateIPAdapterV2,
ipAdaptersV2: validateIPAdaptersV2, ipAdaptersV2: validateIPAdaptersV2,
layer: validateLayer,
layers: validateLayers, layers: validateLayers,
} as const; } as const;