feat(ui): control adapter recall for control layers

- Add set of metadata handlers for the control layers CAs
- Use these conditionally depending on the active tab - when recalling on txt2img, the CAs go to control layers, else they go to the old CA area.
This commit is contained in:
psychedelicious 2024-05-02 17:41:02 +10:00 committed by Kent Keirsey
parent 4cd78b9478
commit 6363095b29
13 changed files with 654 additions and 14 deletions

View File

@ -340,6 +340,12 @@ export const controlLayersSlice = createSlice({
const layer = selectCALayerOrThrow(state, layerId);
layer.controlAdapter.isProcessingImage = isProcessingImage;
},
caLayerControlNetsDeleted: (state) => {
state.layers = state.layers.filter((l) => !isControlAdapterLayer(l) || l.controlAdapter.type !== 'controlnet');
},
caLayerT2IAdaptersDeleted: (state) => {
state.layers = state.layers.filter((l) => !isControlAdapterLayer(l) || l.controlAdapter.type !== 't2i_adapter');
},
//#endregion
//#region IP Adapter Layers
@ -389,6 +395,9 @@ export const controlLayersSlice = createSlice({
const layer = selectIPALayerOrThrow(state, layerId);
layer.ipAdapter.clipVisionModel = clipVisionModel;
},
ipaLayersDeleted: (state) => {
state.layers = state.layers.filter((l) => !isIPAdapterLayer(l));
},
//#endregion
//#region CA or IPA Layers
@ -741,12 +750,15 @@ export const {
caLayerIsFilterEnabledChanged,
caLayerOpacityChanged,
caLayerIsProcessingImageChanged,
caLayerControlNetsDeleted,
caLayerT2IAdaptersDeleted,
// IPA Layers
ipaLayerAdded,
ipaLayerImageChanged,
ipaLayerMethodChanged,
ipaLayerModelChanged,
ipaLayerCLIPVisionModelChanged,
ipaLayersDeleted,
// CA or IPA Layers
caOrIPALayerWeightChanged,
caOrIPALayerBeginEndStepPctChanged,

View File

@ -405,7 +405,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
},
};
const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
export const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
type: 'controlnet',
model: null,
weight: 1,
@ -417,7 +417,7 @@ const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
export const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
type: 't2i_adapter',
model: null,
weight: 1,
@ -428,7 +428,7 @@ const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
const initialIPAdapterV2: Omit<IPAdapterConfigV2, 'id'> = {
export const initialIPAdapterV2: Omit<IPAdapterConfigV2, 'id'> = {
type: 'ip_adapter',
image: null,
model: null,

View File

@ -1,9 +1,14 @@
import { useAppSelector } from 'app/store/storeHooks';
import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets';
import { MetadataControlNetsV2 } from 'features/metadata/components/MetadataControlNetsV2';
import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
import { MetadataIPAdaptersV2 } from 'features/metadata/components/MetadataIPAdaptersV2';
import { MetadataItem } from 'features/metadata/components/MetadataItem';
import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
import { MetadataT2IAdaptersV2 } from 'features/metadata/components/MetadataT2IAdaptersV2';
import { handlers } from 'features/metadata/util/handlers';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
type Props = {
@ -11,6 +16,7 @@ type Props = {
};
const ImageMetadataActions = (props: Props) => {
const activeTabName = useAppSelector(activeTabNameSelector);
const { metadata } = props;
if (!metadata || Object.keys(metadata).length === 0) {
@ -46,9 +52,12 @@ const ImageMetadataActions = (props: Props) => {
<MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
<MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
<MetadataLoRAs metadata={metadata} />
<MetadataControlNets metadata={metadata} />
<MetadataT2IAdapters metadata={metadata} />
<MetadataIPAdapters metadata={metadata} />
{activeTabName !== 'txt2img' && <MetadataControlNets metadata={metadata} />}
{activeTabName !== 'txt2img' && <MetadataT2IAdapters metadata={metadata} />}
{activeTabName !== 'txt2img' && <MetadataIPAdapters metadata={metadata} />}
{activeTabName === 'txt2img' && <MetadataControlNetsV2 metadata={metadata} />}
{activeTabName === 'txt2img' && <MetadataT2IAdaptersV2 metadata={metadata} />}
{activeTabName === 'txt2img' && <MetadataIPAdaptersV2 metadata={metadata} />}
</>
);
};

View File

@ -1,8 +1,11 @@
import { useAppSelector } from 'app/store/storeHooks';
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useCallback, useEffect, useState } from 'react';
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
export const useImageActions = (image_name?: string) => {
const activeTabName = useAppSelector(activeTabNameSelector);
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
const [hasMetadata, setHasMetadata] = useState(false);
const [hasSeed, setHasSeed] = useState(false);
@ -40,13 +43,13 @@ export const useImageActions = (image_name?: string) => {
}, [metadata]);
const recallAll = useCallback(() => {
parseAndRecallAllMetadata(metadata);
}, [metadata]);
parseAndRecallAllMetadata(metadata, activeTabName === 'txt2img');
}, [activeTabName, metadata]);
const remix = useCallback(() => {
// Recalls all metadata parameters except seed
parseAndRecallAllMetadata(metadata, ['seed']);
}, [metadata]);
parseAndRecallAllMetadata(metadata, activeTabName === 'txt2img', ['seed']);
}, [activeTabName, metadata]);
const recallSeed = useCallback(() => {
handlers.seed.parse(metadata).then((seed) => {

View File

@ -0,0 +1,72 @@
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { ControlNetConfigV2Metadata, MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataControlNetsV2 = ({ metadata }: Props) => {
const [controlNets, setControlNets] = useState<ControlNetConfigV2Metadata[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.controlNetsV2.parse(metadata);
setControlNets(parsed);
} catch (e) {
setControlNets([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.controlNetsV2.getLabel(), []);
return (
<>
{controlNets.map((controlNet) => (
<MetadataViewControlNet
key={controlNet.id}
label={label}
controlNet={controlNet}
handlers={handlers.controlNetsV2}
/>
))}
</>
);
};
const MetadataViewControlNet = ({
label,
controlNet,
handlers,
}: {
label: string;
controlNet: ControlNetConfigV2Metadata;
handlers: MetadataHandlers<ControlNetConfigV2Metadata[], ControlNetConfigV2Metadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(controlNet, true);
}, [handlers, controlNet]);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(controlNet);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, controlNet]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,72 @@
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { IPAdapterConfigV2Metadata, MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataIPAdaptersV2 = ({ metadata }: Props) => {
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfigV2Metadata[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.ipAdaptersV2.parse(metadata);
setIPAdapters(parsed);
} catch (e) {
setIPAdapters([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.ipAdaptersV2.getLabel(), []);
return (
<>
{ipAdapters.map((ipAdapter) => (
<MetadataViewIPAdapter
key={ipAdapter.id}
label={label}
ipAdapter={ipAdapter}
handlers={handlers.ipAdaptersV2}
/>
))}
</>
);
};
const MetadataViewIPAdapter = ({
label,
ipAdapter,
handlers,
}: {
label: string;
ipAdapter: IPAdapterConfigV2Metadata;
handlers: MetadataHandlers<IPAdapterConfigV2Metadata[], IPAdapterConfigV2Metadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(ipAdapter, true);
}, [handlers, ipAdapter]);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(ipAdapter);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, ipAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -0,0 +1,72 @@
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers, T2IAdapterConfigV2Metadata } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
type Props = {
metadata: unknown;
};
export const MetadataT2IAdaptersV2 = ({ metadata }: Props) => {
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfigV2Metadata[]>([]);
useEffect(() => {
const parse = async () => {
try {
const parsed = await handlers.t2iAdaptersV2.parse(metadata);
setT2IAdapters(parsed);
} catch (e) {
setT2IAdapters([]);
}
};
parse();
}, [metadata]);
const label = useMemo(() => handlers.t2iAdaptersV2.getLabel(), []);
return (
<>
{t2iAdapters.map((t2iAdapter) => (
<MetadataViewT2IAdapter
key={t2iAdapter.id}
label={label}
t2iAdapter={t2iAdapter}
handlers={handlers.t2iAdaptersV2}
/>
))}
</>
);
};
const MetadataViewT2IAdapter = ({
label,
t2iAdapter,
handlers,
}: {
label: string;
t2iAdapter: T2IAdapterConfigV2Metadata;
handlers: MetadataHandlers<T2IAdapterConfigV2Metadata[], T2IAdapterConfigV2Metadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {
return;
}
handlers.recallItem(t2iAdapter, true);
}, [handlers, t2iAdapter]);
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
useEffect(() => {
const _renderValue = async () => {
if (!handlers.renderItemValue) {
setRenderedValue(null);
return;
}
const rendered = await handlers.renderItemValue(t2iAdapter);
setRenderedValue(rendered);
};
_renderValue();
}, [handlers, t2iAdapter]);
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
};

View File

@ -1,4 +1,5 @@
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { ControlNetConfigV2, IPAdapterConfigV2, T2IAdapterConfigV2 } from 'features/controlLayers/util/controlAdapters';
import type { O } from 'ts-toolbelt';
/**
@ -135,3 +136,11 @@ export type AnyControlAdapterConfigMetadata =
| ControlNetConfigMetadata
| T2IAdapterConfigMetadata
| IPAdapterConfigMetadata;
export type ControlNetConfigV2Metadata = O.NonNullable<ControlNetConfigV2, 'model'>;
export type T2IAdapterConfigV2Metadata = O.NonNullable<T2IAdapterConfigV2, 'model'>;
export type IPAdapterConfigV2Metadata = O.NonNullable<IPAdapterConfigV2, 'model'>;
export type AnyControlAdapterConfigV2Metadata =
| ControlNetConfigV2Metadata
| T2IAdapterConfigV2Metadata
| IPAdapterConfigV2Metadata;

View File

@ -3,6 +3,7 @@ import { toast } from 'common/util/toast';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
AnyControlAdapterConfigMetadata,
AnyControlAdapterConfigV2Metadata,
BuildMetadataHandlers,
MetadataGetLabelFunc,
MetadataHandlers,
@ -43,6 +44,14 @@ const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfig
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};
const renderControlAdapterValueV2: MetadataRenderValueFunc<AnyControlAdapterConfigV2Metadata> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.model.key ?? 'none');
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
} catch {
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};
const parameterSetToast = (parameter: string, description?: string) => {
toast({
@ -341,6 +350,36 @@ export const handlers = {
itemValidator: validators.t2iAdapter,
renderItemValue: renderControlAdapterValue,
}),
controlNetsV2: buildHandlers({
getLabel: () => t('common.controlNet'),
parser: parsers.controlNetsV2,
itemParser: parsers.controlNetV2,
recaller: recallers.controlNetsV2,
itemRecaller: recallers.controlNetV2,
validator: validators.controlNetsV2,
itemValidator: validators.controlNetV2,
renderItemValue: renderControlAdapterValueV2,
}),
ipAdaptersV2: buildHandlers({
getLabel: () => t('common.ipAdapter'),
parser: parsers.ipAdaptersV2,
itemParser: parsers.ipAdapterV2,
recaller: recallers.ipAdaptersV2,
itemRecaller: recallers.ipAdapterV2,
validator: validators.ipAdaptersV2,
itemValidator: validators.ipAdapterV2,
renderItemValue: renderControlAdapterValueV2,
}),
t2iAdaptersV2: buildHandlers({
getLabel: () => t('common.t2iAdapter'),
parser: parsers.t2iAdaptersV2,
itemParser: parsers.t2iAdapterV2,
recaller: recallers.t2iAdaptersV2,
itemRecaller: recallers.t2iAdapterV2,
validator: validators.t2iAdaptersV2,
itemValidator: validators.t2iAdapterV2,
renderItemValue: renderControlAdapterValueV2,
}),
} as const;
export const parseAndRecallPrompts = async (metadata: unknown) => {
@ -395,10 +434,25 @@ export const parseAndRecallImageDimensions = async (metadata: unknown) => {
}
};
export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof typeof handlers)[] = []) => {
// These handlers should be omitted when recalling to control layers
const TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNets', 'ipAdapters', 't2iAdapters'];
// These handlers should be omitted when recalling to the rest of the app
const NOT_TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNetsV2', 'ipAdaptersV2', 't2iAdaptersV2'];
export const parseAndRecallAllMetadata = async (
metadata: unknown,
toControlLayers: boolean,
skip: (keyof typeof handlers)[] = []
) => {
const skipKeys = skip ?? [];
if (toControlLayers) {
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
} else {
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
}
const results = await Promise.allSettled(
objectKeys(handlers)
.filter((key) => !skip.includes(key))
.filter((key) => !skipKeys.includes(key))
.map((key) => {
const { parse, recall } = handlers[key];
return parse(metadata).then((value) => {

View File

@ -5,13 +5,24 @@ import {
initialT2IAdapter,
} from 'features/controlAdapters/util/buildControlAdapter';
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
import {
CA_PROCESSOR_DATA,
imageDTOToImageWithDims,
initialControlNetV2,
initialIPAdapterV2,
initialT2IAdapterV2,
isProcessorTypeV2,
} from 'features/controlLayers/util/controlAdapters';
import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
ControlNetConfigV2Metadata,
IPAdapterConfigMetadata,
IPAdapterConfigV2Metadata,
MetadataParseFunc,
T2IAdapterConfigMetadata,
T2IAdapterConfigV2Metadata,
} from 'features/metadata/types';
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
@ -58,7 +69,7 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { get, isArray, isString } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import {
isControlNetModelConfig,
@ -428,6 +439,203 @@ const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (
}
};
//#region V2/Control Layers
const parseControlNetV2: MetadataParseFunc<ControlNetConfigV2Metadata> = async (metadataItem) => {
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
const image = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zControlField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const control_weight = zControlField.shape.control_weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'control_weight'));
const begin_step_percent = zControlField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zControlField.shape.end_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'end_step_percent'));
const control_mode = zControlField.shape.control_mode
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'control_mode'));
const id = uuidv4();
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
: null;
const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialControlNetV2.beginEndStepPct[0],
end_step_percent ?? initialControlNetV2.beginEndStepPct[1],
];
const imageDTO = image ? await getImageDTO(image.image_name) : null;
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
const controlNet: ControlNetConfigV2Metadata = {
id,
type: 'controlnet',
model: zModelIdentifierField.parse(controlNetModel),
weight: typeof control_weight === 'number' ? control_weight : initialControlNetV2.weight,
beginEndStepPct,
controlMode: control_mode ?? initialControlNetV2.controlMode,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
processorConfig,
isProcessingImage: false,
};
return controlNet;
};
const parseAllControlNetsV2: MetadataParseFunc<ControlNetConfigV2Metadata[]> = async (metadata) => {
try {
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray || undefined);
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNetV2(cn)));
const controlNets = parseResults
.filter((result): result is PromiseFulfilledResult<ControlNetConfigV2Metadata> => result.status === 'fulfilled')
.map((result) => result.value);
return controlNets;
} catch {
return [];
}
};
const parseT2IAdapterV2: MetadataParseFunc<T2IAdapterConfigV2Metadata> = async (metadataItem) => {
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
const image = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const processedImage = zT2IAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'processed_image'));
const weight = zT2IAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zT2IAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'end_step_percent'));
const id = uuidv4();
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
: null;
const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialT2IAdapterV2.beginEndStepPct[0],
end_step_percent ?? initialT2IAdapterV2.beginEndStepPct[1],
];
const imageDTO = image ? await getImageDTO(image.image_name) : null;
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
const t2iAdapter: T2IAdapterConfigV2Metadata = {
id,
type: 't2i_adapter',
model: zModelIdentifierField.parse(t2iAdapterModel),
weight: typeof weight === 'number' ? weight : initialT2IAdapterV2.weight,
beginEndStepPct,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
processorConfig,
isProcessingImage: false,
};
return t2iAdapter;
};
const parseAllT2IAdaptersV2: MetadataParseFunc<T2IAdapterConfigV2Metadata[]> = async (metadata) => {
try {
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapterV2(t2iAdapter)));
const t2iAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigV2Metadata> => result.status === 'fulfilled')
.map((result) => result.value);
return t2iAdapters;
} catch {
return [];
}
};
const parseIPAdapterV2: MetadataParseFunc<IPAdapterConfigV2Metadata> = async (metadataItem) => {
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
const image = zIPAdapterField.shape.image
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'image'));
const weight = zIPAdapterField.shape.weight
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'weight'));
const method = zIPAdapterField.shape.method
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'method'));
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'begin_step_percent'));
const end_step_percent = zIPAdapterField.shape.end_step_percent
.nullish()
.catch(null)
.parse(await getProperty(metadataItem, 'end_step_percent'));
const id = uuidv4();
const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialIPAdapterV2.beginEndStepPct[0],
end_step_percent ?? initialIPAdapterV2.beginEndStepPct[1],
];
const imageDTO = image ? await getImageDTO(image.image_name) : null;
const ipAdapter: IPAdapterConfigV2Metadata = {
id,
type: 'ip_adapter',
model: zModelIdentifierField.parse(ipAdapterModel),
weight: typeof weight === 'number' ? weight : initialIPAdapterV2.weight,
beginEndStepPct,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
clipVisionModel: initialIPAdapterV2.clipVisionModel, // TODO: This needs to be added to the zIPAdapterField...
method: method ?? initialIPAdapterV2.method,
};
return ipAdapter;
};
const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => {
try {
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapterV2(ipAdapter)));
const ipAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigV2Metadata> => result.status === 'fulfilled')
.map((result) => result.value);
return ipAdapters;
} catch {
return [];
}
};
export const parsers = {
createdBy: parseCreatedBy,
generationMode: parseGenerationMode,
@ -464,4 +672,10 @@ export const parsers = {
t2iAdapters: parseAllT2IAdapters,
ipAdapter: parseIPAdapter,
ipAdapters: parseAllIPAdapters,
controlNetV2: parseControlNetV2,
controlNetsV2: parseAllControlNetsV2,
t2iAdapterV2: parseT2IAdapterV2,
t2iAdaptersV2: parseAllT2IAdaptersV2,
ipAdapterV2: parseIPAdapterV2,
ipAdaptersV2: parseAllIPAdaptersV2,
} as const;

View File

@ -6,7 +6,12 @@ import {
t2iAdaptersReset,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import {
caLayerAdded,
caLayerControlNetsDeleted,
caLayerT2IAdaptersDeleted,
heightChanged,
ipaLayerAdded,
ipaLayersDeleted,
negativePrompt2Changed,
negativePromptChanged,
positivePrompt2Changed,
@ -18,9 +23,12 @@ import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
ControlNetConfigV2Metadata,
IPAdapterConfigMetadata,
IPAdapterConfigV2Metadata,
MetadataRecallFunc,
T2IAdapterConfigMetadata,
T2IAdapterConfigV2Metadata,
} from 'features/metadata/types';
import { modelSelected } from 'features/parameters/store/actions';
import {
@ -234,6 +242,52 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
});
};
//#region V2/Control Layer
const recallControlNetV2: MetadataRecallFunc<ControlNetConfigV2Metadata> = (controlNet) => {
getStore().dispatch(caLayerAdded(controlNet));
};
const recallControlNetsV2: MetadataRecallFunc<ControlNetConfigV2Metadata[]> = (controlNets) => {
const { dispatch } = getStore();
dispatch(caLayerControlNetsDeleted());
if (!controlNets.length) {
return;
}
controlNets.forEach((controlNet) => {
dispatch(caLayerAdded(controlNet));
});
};
const recallT2IAdapterV2: MetadataRecallFunc<T2IAdapterConfigV2Metadata> = (t2iAdapter) => {
getStore().dispatch(caLayerAdded(t2iAdapter));
};
const recallT2IAdaptersV2: MetadataRecallFunc<T2IAdapterConfigV2Metadata[]> = (t2iAdapters) => {
const { dispatch } = getStore();
dispatch(caLayerT2IAdaptersDeleted());
if (!t2iAdapters.length) {
return;
}
t2iAdapters.forEach((t2iAdapters) => {
dispatch(caLayerAdded(t2iAdapters));
});
};
const recallIPAdapterV2: MetadataRecallFunc<IPAdapterConfigV2Metadata> = (ipAdapter) => {
getStore().dispatch(ipaLayerAdded(ipAdapter));
};
const recallIPAdaptersV2: MetadataRecallFunc<IPAdapterConfigV2Metadata[]> = (ipAdapters) => {
const { dispatch } = getStore();
dispatch(ipaLayersDeleted());
if (!ipAdapters.length) {
return;
}
ipAdapters.forEach((ipAdapter) => {
dispatch(ipaLayerAdded(ipAdapter));
});
};
export const recallers = {
positivePrompt: recallPositivePrompt,
negativePrompt: recallNegativePrompt,
@ -268,4 +322,10 @@ export const recallers = {
t2iAdapter: recallT2IAdapter,
ipAdapters: recallIPAdapters,
ipAdapter: recallIPAdapter,
controlNetV2: recallControlNetV2,
controlNetsV2: recallControlNetsV2,
t2iAdapterV2: recallT2IAdapterV2,
t2iAdaptersV2: recallT2IAdaptersV2,
ipAdapterV2: recallIPAdapterV2,
ipAdaptersV2: recallIPAdaptersV2,
} as const;

View File

@ -2,9 +2,12 @@ import { getStore } from 'app/store/nanostores/store';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
ControlNetConfigMetadata,
ControlNetConfigV2Metadata,
IPAdapterConfigMetadata,
IPAdapterConfigV2Metadata,
MetadataValidateFunc,
T2IAdapterConfigMetadata,
T2IAdapterConfigV2Metadata,
} from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
@ -108,6 +111,60 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
return new Promise((resolve) => resolve(validatedIPAdapters));
};
const validateControlNetV2: MetadataValidateFunc<ControlNetConfigV2Metadata> = (controlNet) => {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
return new Promise((resolve) => resolve(controlNet));
};
const validateControlNetsV2: MetadataValidateFunc<ControlNetConfigV2Metadata[]> = (controlNets) => {
const validatedControlNets: ControlNetConfigV2Metadata[] = [];
controlNets.forEach((controlNet) => {
try {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
validatedControlNets.push(controlNet);
} catch {
// This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedControlNets));
};
const validateT2IAdapterV2: MetadataValidateFunc<T2IAdapterConfigV2Metadata> = (t2iAdapter) => {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(t2iAdapter));
};
const validateT2IAdaptersV2: MetadataValidateFunc<T2IAdapterConfigV2Metadata[]> = (t2iAdapters) => {
const validatedT2IAdapters: T2IAdapterConfigV2Metadata[] = [];
t2iAdapters.forEach((t2iAdapter) => {
try {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
validatedT2IAdapters.push(t2iAdapter);
} catch {
// This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedT2IAdapters));
};
const validateIPAdapterV2: MetadataValidateFunc<IPAdapterConfigV2Metadata> = (ipAdapter) => {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(ipAdapter));
};
const validateIPAdaptersV2: MetadataValidateFunc<IPAdapterConfigV2Metadata[]> = (ipAdapters) => {
const validatedIPAdapters: IPAdapterConfigV2Metadata[] = [];
ipAdapters.forEach((ipAdapter) => {
try {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
validatedIPAdapters.push(ipAdapter);
} catch {
// This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid.
}
});
return new Promise((resolve) => resolve(validatedIPAdapters));
};
export const validators = {
refinerModel: validateRefinerModel,
vaeModel: validateVAEModel,
@ -119,4 +176,10 @@ export const validators = {
t2iAdapters: validateT2IAdapters,
ipAdapter: validateIPAdapter,
ipAdapters: validateIPAdapters,
controlNetV2: validateControlNetV2,
controlNetsV2: validateControlNetsV2,
t2iAdapterV2: validateT2IAdapterV2,
t2iAdaptersV2: validateT2IAdaptersV2,
ipAdapterV2: validateIPAdapterV2,
ipAdaptersV2: validateIPAdaptersV2,
} as const;

View File

@ -43,7 +43,7 @@ export const usePreselectedImage = (selectedImage?: {
const handleUseAllMetadata = useCallback(() => {
if (selectedImageMetadata) {
parseAndRecallAllMetadata(selectedImageMetadata);
parseAndRecallAllMetadata(selectedImageMetadata, true);
}
}, [selectedImageMetadata]);