From dfbd7eb1cfc8bd02eef63d6f3a0924513480f8bb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 8 May 2024 09:58:34 +1000 Subject: [PATCH] feat(ui): individual layer recall --- .../controlLayers/store/controlLayersSlice.ts | 15 ++++ .../ImageMetadataActions.tsx | 3 +- .../metadata/components/MetadataLayers.tsx | 68 +++++++++++++++++++ .../src/features/metadata/util/handlers.ts | 24 ++++++- .../metadata/util/modelFetchingHelpers.ts | 19 ++++++ .../web/src/features/metadata/util/parsers.ts | 29 ++++---- .../src/features/metadata/util/recallers.ts | 32 ++++++--- .../src/features/metadata/util/validators.ts | 41 +++++++---- 8 files changed, 192 insertions(+), 39 deletions(-) create mode 100644 invokeai/frontend/web/src/features/metadata/components/MetadataLayers.tsx diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 6f6176c242..bc9f133075 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -124,6 +124,12 @@ const getVectorMaskPreviewColor = (state: ControlLayersState): RgbColor => { const lastColor = rgLayers[rgLayers.length - 1]?.previewColor; return LayerColors.next(lastColor); }; +const deselectAllLayers = (state: ControlLayersState) => { + for (const layer of state.layers.filter(isRenderableLayer)) { + layer.isSelected = false; + } + state.selectedLayerId = null; +}; export const controlLayersSlice = createSlice({ name: 'controlLayers', @@ -256,6 +262,7 @@ export const controlLayersSlice = createSlice({ }), }, caLayerRecalled: (state, action: PayloadAction) => { + deselectAllLayers(state); state.layers.push({ ...action.payload, isSelected: true }); state.selectedLayerId = action.payload.id; }, @@ -470,6 +477,7 @@ export const controlLayersSlice = createSlice({ prepare: () => ({ payload: { layerId: uuidv4() } }), }, rgLayerRecalled: (state, action: PayloadAction) => { + deselectAllLayers(state); state.layers.push({ ...action.payload, isSelected: true }); state.selectedLayerId = action.payload.id; }, @@ -665,6 +673,12 @@ export const controlLayersSlice = createSlice({ }, prepare: (imageDTO: ImageDTO | null) => ({ payload: { layerId: 'initial_image_layer', imageDTO } }), }, + iiLayerRecalled: (state, action: PayloadAction) => { + deselectAllLayers(state); + state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true)); + state.layers.push({ ...action.payload, isSelected: true }); + state.selectedLayerId = action.payload.id; + }, iiLayerImageChanged: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; const layer = selectIILayerOrThrow(state, layerId); @@ -859,6 +873,7 @@ export const { rgLayerIPAdapterCLIPVisionModelChanged, // II Layer iiLayerAdded, + iiLayerRecalled, iiLayerImageChanged, iiLayerOpacityChanged, iiLayerDenoisingStrengthChanged, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 7dd2be55b0..04e8fd2eca 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -4,6 +4,7 @@ import { MetadataControlNetsV2 } from 'features/metadata/components/MetadataCont import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters'; import { MetadataIPAdaptersV2 } from 'features/metadata/components/MetadataIPAdaptersV2'; import { MetadataItem } from 'features/metadata/components/MetadataItem'; +import { MetadataLayers } from 'features/metadata/components/MetadataLayers'; import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs'; import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters'; import { MetadataT2IAdaptersV2 } from 'features/metadata/components/MetadataT2IAdaptersV2'; @@ -51,8 +52,8 @@ const ImageMetadataActions = (props: Props) => { - + {activeTabName !== 'generation' && } {activeTabName !== 'generation' && } {activeTabName !== 'generation' && } diff --git a/invokeai/frontend/web/src/features/metadata/components/MetadataLayers.tsx b/invokeai/frontend/web/src/features/metadata/components/MetadataLayers.tsx new file mode 100644 index 0000000000..ab4ce03987 --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/components/MetadataLayers.tsx @@ -0,0 +1,68 @@ +import type { Layer } from 'features/controlLayers/store/types'; +import { MetadataItemView } from 'features/metadata/components/MetadataItemView'; +import type { MetadataHandlers } from 'features/metadata/types'; +import { handlers } from 'features/metadata/util/handlers'; +import { useCallback, useEffect, useMemo, useState } from 'react'; + +type Props = { + metadata: unknown; +}; + +export const MetadataLayers = ({ metadata }: Props) => { + const [layers, setLayers] = useState([]); + + useEffect(() => { + const parse = async () => { + try { + const parsed = await handlers.layers.parse(metadata); + setLayers(parsed); + } catch (e) { + setLayers([]); + } + }; + parse(); + }, [metadata]); + + const label = useMemo(() => handlers.layers.getLabel(), []); + + return ( + <> + {layers.map((layer) => ( + + ))} + + ); +}; + +const MetadataViewLayer = ({ + label, + layer, + handlers, +}: { + label: string; + layer: Layer; + handlers: MetadataHandlers; +}) => { + const onRecall = useCallback(() => { + if (!handlers.recallItem) { + return; + } + handlers.recallItem(layer, true); + }, [handlers, layer]); + + const [renderedValue, setRenderedValue] = useState(null); + useEffect(() => { + const _renderValue = async () => { + if (!handlers.renderItemValue) { + setRenderedValue(null); + return; + } + const rendered = await handlers.renderItemValue(layer); + setRenderedValue(rendered); + }; + + _renderValue(); + }, [handlers, layer]); + + return ; +}; diff --git a/invokeai/frontend/web/src/features/metadata/util/handlers.ts b/invokeai/frontend/web/src/features/metadata/util/handlers.ts index 4cbe69668f..a2ba3dfc22 100644 --- a/invokeai/frontend/web/src/features/metadata/util/handlers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/handlers.ts @@ -17,6 +17,7 @@ import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers'; import { validators } from 'features/metadata/util/validators'; import type { ModelIdentifierField } from 'features/nodes/types/common'; import { t } from 'i18next'; +import { assert } from 'tsafe'; import { parsers } from './parsers'; import { recallers } from './recallers'; @@ -53,8 +54,23 @@ const renderControlAdapterValueV2: MetadataRenderValueFunc = async (value) => { - return `${value.length} ${t('controlLayers.layers', { count: value.length })}`; +const renderLayerValue: MetadataRenderValueFunc = async (layer) => { + if (layer.type === 'initial_image_layer') { + return t('controlLayers.initialImageLayer'); + } + if (layer.type === 'control_adapter_layer') { + return t('controlLayers.controlAdapterLayer'); + } + if (layer.type === 'ip_adapter_layer') { + return t('controlLayers.ipAdapterLayer'); + } + if (layer.type === 'regional_guidance_layer') { + return t('controlLayers.regionalGuidanceLayer'); + } + assert(false, 'Unknown layer type'); +}; +const renderLayersValue: MetadataRenderValueFunc = async (layers) => { + return `${layers.length} ${t('controlLayers.layers', { count: layers.length })}`; }; const parameterSetToast = (parameter: string, description?: string) => { @@ -389,8 +405,12 @@ export const handlers = { layers: buildHandlers({ getLabel: () => t('controlLayers.layers_other'), parser: parsers.layers, + itemParser: parsers.layer, recaller: recallers.layers, + itemRecaller: recallers.layer, validator: validators.layers, + itemValidator: validators.layer, + renderItemValue: renderLayerValue, renderValue: renderLayersValue, getIsVisible: (value) => value.length > 0, }), diff --git a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts index a237582ed8..a2db414937 100644 --- a/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/modelFetchingHelpers.ts @@ -1,4 +1,5 @@ import { getStore } from 'app/store/nanostores/store'; +import type { ModelIdentifierField } from 'features/nodes/types/common'; import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common'; import { modelsApi } from 'services/api/endpoints/models'; import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types'; @@ -68,6 +69,24 @@ const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: } }; +/** + * Fetches the model config given an identifier. First attempts to fetch by key, then falls back to fetching by attrs. + * @param identifier The model identifier. + * @returns A promise that resolves to the model config. + * @throws {ModelConfigNotFoundError} If the model config is unable to be fetched. + */ +export const fetchModelConfigByIdentifier = async (identifier: ModelIdentifierField): Promise => { + try { + return await fetchModelConfig(identifier.key); + } catch { + try { + return await fetchModelConfigByAttrs(identifier.name, identifier.base, identifier.type); + } catch { + throw new ModelConfigNotFoundError(`Unable to retrieve model config for identifier ${identifier}`); + } + } +}; + /** * Fetches the model config for a given model key and type, and ensures that the model config is of a specific type. * @param key The model key. diff --git a/invokeai/frontend/web/src/features/metadata/util/parsers.ts b/invokeai/frontend/web/src/features/metadata/util/parsers.ts index 25ab72536a..f59bbc90c6 100644 --- a/invokeai/frontend/web/src/features/metadata/util/parsers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/parsers.ts @@ -625,19 +625,6 @@ const parseIPAdapterV2: MetadataParseFunc = async (me return ipAdapter; }; -const parseLayers: MetadataParseFunc = async (metadata) => { - try { - const layersRaw = await getProperty(metadata, 'layers', isArray); - const parseResults = await Promise.allSettled(layersRaw.map((layerRaw) => zLayer.parseAsync(layerRaw))); - const layers = parseResults - .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') - .map((result) => result.value); - return layers; - } catch { - return []; - } -}; - const parseAllIPAdaptersV2: MetadataParseFunc = async (metadata) => { try { const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray); @@ -651,6 +638,21 @@ const parseAllIPAdaptersV2: MetadataParseFunc = asy } }; +const parseLayer: MetadataParseFunc = async (metadataItem) => zLayer.parseAsync(metadataItem); + +const parseLayers: MetadataParseFunc = async (metadata) => { + try { + const layersRaw = await getProperty(metadata, 'layers', isArray); + const parseResults = await Promise.allSettled(layersRaw.map(parseLayer)); + const layers = parseResults + .filter((result): result is PromiseFulfilledResult => result.status === 'fulfilled') + .map((result) => result.value); + return layers; + } catch { + return []; + } +}; + export const parsers = { createdBy: parseCreatedBy, generationMode: parseGenerationMode, @@ -693,5 +695,6 @@ export const parsers = { t2iAdaptersV2: parseAllT2IAdaptersV2, ipAdapterV2: parseIPAdapterV2, ipAdaptersV2: parseAllIPAdaptersV2, + layer: parseLayer, layers: parseLayers, } as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/recallers.ts b/invokeai/frontend/web/src/features/metadata/util/recallers.ts index 3782c789e0..390e840776 100644 --- a/invokeai/frontend/web/src/features/metadata/util/recallers.ts +++ b/invokeai/frontend/web/src/features/metadata/util/recallers.ts @@ -13,6 +13,7 @@ import { caLayerT2IAdaptersDeleted, heightChanged, iiLayerAdded, + iiLayerRecalled, ipaLayerAdded, ipaLayerRecalled, ipaLayersDeleted, @@ -295,21 +296,29 @@ const recallIPAdaptersV2: MetadataRecallFunc = (ipA }); }; +const recallLayer: MetadataRecallFunc = (layer) => { + const { dispatch } = getStore(); + switch (layer.type) { + case 'control_adapter_layer': + dispatch(caLayerRecalled(layer)); + break; + case 'ip_adapter_layer': + dispatch(ipaLayerRecalled(layer)); + break; + case 'regional_guidance_layer': + dispatch(rgLayerRecalled(layer)); + break; + case 'initial_image_layer': + dispatch(iiLayerRecalled(layer)); + break; + } +}; + const recallLayers: MetadataRecallFunc = (layers) => { const { dispatch } = getStore(); dispatch(allLayersDeleted()); for (const l of layers) { - switch (l.type) { - case 'control_adapter_layer': - dispatch(caLayerRecalled(l)); - break; - case 'ip_adapter_layer': - dispatch(ipaLayerRecalled(l)); - break; - case 'regional_guidance_layer': - dispatch(rgLayerRecalled(l)); - break; - } + recallLayer(l); } }; @@ -353,5 +362,6 @@ export const recallers = { t2iAdaptersV2: recallT2IAdaptersV2, ipAdapterV2: recallIPAdapterV2, ipAdaptersV2: recallIPAdaptersV2, + layer: recallLayer, layers: recallLayers, } as const; diff --git a/invokeai/frontend/web/src/features/metadata/util/validators.ts b/invokeai/frontend/web/src/features/metadata/util/validators.ts index aca988f85a..7381d7aee0 100644 --- a/invokeai/frontend/web/src/features/metadata/util/validators.ts +++ b/invokeai/frontend/web/src/features/metadata/util/validators.ts @@ -10,9 +10,10 @@ import type { T2IAdapterConfigMetadata, T2IAdapterConfigV2Metadata, } from 'features/metadata/types'; -import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; +import { fetchModelConfigByIdentifier, InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers'; import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas'; import type { BaseModelType } from 'services/api/types'; +import { assert } from 'tsafe'; /** * Checks the given base model type against the currently-selected model's base type and throws an error if they are @@ -166,21 +167,36 @@ const validateIPAdaptersV2: MetadataValidateFunc = return new Promise((resolve) => resolve(validatedIPAdapters)); }; +const validateLayer: MetadataValidateFunc = async (layer) => { + if (layer.type === 'control_adapter_layer') { + const model = layer.controlAdapter.model; + assert(model, 'Control Adapter layer missing model'); + validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); + fetchModelConfigByIdentifier(model); + } + if (layer.type === 'ip_adapter_layer') { + const model = layer.ipAdapter.model; + assert(model, 'IP Adapter layer missing model'); + validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); + fetchModelConfigByIdentifier(model); + } + if (layer.type === 'regional_guidance_layer') { + for (const ipa of layer.ipAdapters) { + const model = ipa.model; + assert(model, 'IP Adapter layer missing model'); + validateBaseCompatibility(model.base, 'Layer incompatible with currently-selected model'); + fetchModelConfigByIdentifier(model); + } + } + + return layer; +}; + const validateLayers: MetadataValidateFunc = (layers) => { const validatedLayers: Layer[] = []; for (const l of layers) { try { - if (l.type === 'control_adapter_layer') { - validateBaseCompatibility(l.controlAdapter.model?.base, 'Layer incompatible with currently-selected model'); - } - if (l.type === 'ip_adapter_layer') { - validateBaseCompatibility(l.ipAdapter.model?.base, 'Layer incompatible with currently-selected model'); - } - if (l.type === 'regional_guidance_layer') { - for (const ipa of l.ipAdapters) { - validateBaseCompatibility(ipa.model?.base, 'Layer incompatible with currently-selected model'); - } - } + validateLayer(l); validatedLayers.push(l); } catch { // This is a no-op - we want to continue validating the rest of the layers, and an empty list is valid. @@ -206,5 +222,6 @@ export const validators = { t2iAdaptersV2: validateT2IAdaptersV2, ipAdapterV2: validateIPAdapterV2, ipAdaptersV2: validateIPAdaptersV2, + layer: validateLayer, layers: validateLayers, } as const;