mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): control adapter recall for control layers
- Add set of metadata handlers for the control layers CAs - Use these conditionally depending on the active tab - when recalling on txt2img, the CAs go to control layers, else they go to the old CA area.
This commit is contained in:
parent
4cd78b9478
commit
6363095b29
@ -340,6 +340,12 @@ export const controlLayersSlice = createSlice({
|
||||
const layer = selectCALayerOrThrow(state, layerId);
|
||||
layer.controlAdapter.isProcessingImage = isProcessingImage;
|
||||
},
|
||||
caLayerControlNetsDeleted: (state) => {
|
||||
state.layers = state.layers.filter((l) => !isControlAdapterLayer(l) || l.controlAdapter.type !== 'controlnet');
|
||||
},
|
||||
caLayerT2IAdaptersDeleted: (state) => {
|
||||
state.layers = state.layers.filter((l) => !isControlAdapterLayer(l) || l.controlAdapter.type !== 't2i_adapter');
|
||||
},
|
||||
//#endregion
|
||||
|
||||
//#region IP Adapter Layers
|
||||
@ -389,6 +395,9 @@ export const controlLayersSlice = createSlice({
|
||||
const layer = selectIPALayerOrThrow(state, layerId);
|
||||
layer.ipAdapter.clipVisionModel = clipVisionModel;
|
||||
},
|
||||
ipaLayersDeleted: (state) => {
|
||||
state.layers = state.layers.filter((l) => !isIPAdapterLayer(l));
|
||||
},
|
||||
//#endregion
|
||||
|
||||
//#region CA or IPA Layers
|
||||
@ -741,12 +750,15 @@ export const {
|
||||
caLayerIsFilterEnabledChanged,
|
||||
caLayerOpacityChanged,
|
||||
caLayerIsProcessingImageChanged,
|
||||
caLayerControlNetsDeleted,
|
||||
caLayerT2IAdaptersDeleted,
|
||||
// IPA Layers
|
||||
ipaLayerAdded,
|
||||
ipaLayerImageChanged,
|
||||
ipaLayerMethodChanged,
|
||||
ipaLayerModelChanged,
|
||||
ipaLayerCLIPVisionModelChanged,
|
||||
ipaLayersDeleted,
|
||||
// CA or IPA Layers
|
||||
caOrIPALayerWeightChanged,
|
||||
caOrIPALayerBeginEndStepPctChanged,
|
||||
|
@ -405,7 +405,7 @@ export const CA_PROCESSOR_DATA: CAProcessorsData = {
|
||||
},
|
||||
};
|
||||
|
||||
const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
|
||||
export const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
|
||||
type: 'controlnet',
|
||||
model: null,
|
||||
weight: 1,
|
||||
@ -417,7 +417,7 @@ const initialControlNetV2: Omit<ControlNetConfigV2, 'id'> = {
|
||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||
};
|
||||
|
||||
const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
|
||||
export const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
|
||||
type: 't2i_adapter',
|
||||
model: null,
|
||||
weight: 1,
|
||||
@ -428,7 +428,7 @@ const initialT2IAdapterV2: Omit<T2IAdapterConfigV2, 'id'> = {
|
||||
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
|
||||
};
|
||||
|
||||
const initialIPAdapterV2: Omit<IPAdapterConfigV2, 'id'> = {
|
||||
export const initialIPAdapterV2: Omit<IPAdapterConfigV2, 'id'> = {
|
||||
type: 'ip_adapter',
|
||||
image: null,
|
||||
model: null,
|
||||
|
@ -1,9 +1,14 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { MetadataControlNets } from 'features/metadata/components/MetadataControlNets';
|
||||
import { MetadataControlNetsV2 } from 'features/metadata/components/MetadataControlNetsV2';
|
||||
import { MetadataIPAdapters } from 'features/metadata/components/MetadataIPAdapters';
|
||||
import { MetadataIPAdaptersV2 } from 'features/metadata/components/MetadataIPAdaptersV2';
|
||||
import { MetadataItem } from 'features/metadata/components/MetadataItem';
|
||||
import { MetadataLoRAs } from 'features/metadata/components/MetadataLoRAs';
|
||||
import { MetadataT2IAdapters } from 'features/metadata/components/MetadataT2IAdapters';
|
||||
import { MetadataT2IAdaptersV2 } from 'features/metadata/components/MetadataT2IAdaptersV2';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
@ -11,6 +16,7 @@ type Props = {
|
||||
};
|
||||
|
||||
const ImageMetadataActions = (props: Props) => {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const { metadata } = props;
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
@ -46,9 +52,12 @@ const ImageMetadataActions = (props: Props) => {
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerStart} />
|
||||
<MetadataItem metadata={metadata} handlers={handlers.refinerSteps} />
|
||||
<MetadataLoRAs metadata={metadata} />
|
||||
<MetadataControlNets metadata={metadata} />
|
||||
<MetadataT2IAdapters metadata={metadata} />
|
||||
<MetadataIPAdapters metadata={metadata} />
|
||||
{activeTabName !== 'txt2img' && <MetadataControlNets metadata={metadata} />}
|
||||
{activeTabName !== 'txt2img' && <MetadataT2IAdapters metadata={metadata} />}
|
||||
{activeTabName !== 'txt2img' && <MetadataIPAdapters metadata={metadata} />}
|
||||
{activeTabName === 'txt2img' && <MetadataControlNetsV2 metadata={metadata} />}
|
||||
{activeTabName === 'txt2img' && <MetadataT2IAdaptersV2 metadata={metadata} />}
|
||||
{activeTabName === 'txt2img' && <MetadataIPAdaptersV2 metadata={metadata} />}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -1,8 +1,11 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { handlers, parseAndRecallAllMetadata, parseAndRecallPrompts } from 'features/metadata/util/handlers';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useDebouncedMetadata } from 'services/api/hooks/useDebouncedMetadata';
|
||||
|
||||
export const useImageActions = (image_name?: string) => {
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const { metadata, isLoading: isLoadingMetadata } = useDebouncedMetadata(image_name);
|
||||
const [hasMetadata, setHasMetadata] = useState(false);
|
||||
const [hasSeed, setHasSeed] = useState(false);
|
||||
@ -40,13 +43,13 @@ export const useImageActions = (image_name?: string) => {
|
||||
}, [metadata]);
|
||||
|
||||
const recallAll = useCallback(() => {
|
||||
parseAndRecallAllMetadata(metadata);
|
||||
}, [metadata]);
|
||||
parseAndRecallAllMetadata(metadata, activeTabName === 'txt2img');
|
||||
}, [activeTabName, metadata]);
|
||||
|
||||
const remix = useCallback(() => {
|
||||
// Recalls all metadata parameters except seed
|
||||
parseAndRecallAllMetadata(metadata, ['seed']);
|
||||
}, [metadata]);
|
||||
parseAndRecallAllMetadata(metadata, activeTabName === 'txt2img', ['seed']);
|
||||
}, [activeTabName, metadata]);
|
||||
|
||||
const recallSeed = useCallback(() => {
|
||||
handlers.seed.parse(metadata).then((seed) => {
|
||||
|
@ -0,0 +1,72 @@
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { ControlNetConfigV2Metadata, MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataControlNetsV2 = ({ metadata }: Props) => {
|
||||
const [controlNets, setControlNets] = useState<ControlNetConfigV2Metadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.controlNetsV2.parse(metadata);
|
||||
setControlNets(parsed);
|
||||
} catch (e) {
|
||||
setControlNets([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.controlNetsV2.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{controlNets.map((controlNet) => (
|
||||
<MetadataViewControlNet
|
||||
key={controlNet.id}
|
||||
label={label}
|
||||
controlNet={controlNet}
|
||||
handlers={handlers.controlNetsV2}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewControlNet = ({
|
||||
label,
|
||||
controlNet,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
controlNet: ControlNetConfigV2Metadata;
|
||||
handlers: MetadataHandlers<ControlNetConfigV2Metadata[], ControlNetConfigV2Metadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(controlNet, true);
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
const rendered = await handlers.renderItemValue(controlNet);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, controlNet]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { IPAdapterConfigV2Metadata, MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataIPAdaptersV2 = ({ metadata }: Props) => {
|
||||
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfigV2Metadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.ipAdaptersV2.parse(metadata);
|
||||
setIPAdapters(parsed);
|
||||
} catch (e) {
|
||||
setIPAdapters([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.ipAdaptersV2.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{ipAdapters.map((ipAdapter) => (
|
||||
<MetadataViewIPAdapter
|
||||
key={ipAdapter.id}
|
||||
label={label}
|
||||
ipAdapter={ipAdapter}
|
||||
handlers={handlers.ipAdaptersV2}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewIPAdapter = ({
|
||||
label,
|
||||
ipAdapter,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
ipAdapter: IPAdapterConfigV2Metadata;
|
||||
handlers: MetadataHandlers<IPAdapterConfigV2Metadata[], IPAdapterConfigV2Metadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(ipAdapter, true);
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
const rendered = await handlers.renderItemValue(ipAdapter);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, ipAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -0,0 +1,72 @@
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers, T2IAdapterConfigV2Metadata } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
type Props = {
|
||||
metadata: unknown;
|
||||
};
|
||||
|
||||
export const MetadataT2IAdaptersV2 = ({ metadata }: Props) => {
|
||||
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfigV2Metadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
try {
|
||||
const parsed = await handlers.t2iAdaptersV2.parse(metadata);
|
||||
setT2IAdapters(parsed);
|
||||
} catch (e) {
|
||||
setT2IAdapters([]);
|
||||
}
|
||||
};
|
||||
parse();
|
||||
}, [metadata]);
|
||||
|
||||
const label = useMemo(() => handlers.t2iAdaptersV2.getLabel(), []);
|
||||
|
||||
return (
|
||||
<>
|
||||
{t2iAdapters.map((t2iAdapter) => (
|
||||
<MetadataViewT2IAdapter
|
||||
key={t2iAdapter.id}
|
||||
label={label}
|
||||
t2iAdapter={t2iAdapter}
|
||||
handlers={handlers.t2iAdaptersV2}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
const MetadataViewT2IAdapter = ({
|
||||
label,
|
||||
t2iAdapter,
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
t2iAdapter: T2IAdapterConfigV2Metadata;
|
||||
handlers: MetadataHandlers<T2IAdapterConfigV2Metadata[], T2IAdapterConfigV2Metadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
return;
|
||||
}
|
||||
handlers.recallItem(t2iAdapter, true);
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||
useEffect(() => {
|
||||
const _renderValue = async () => {
|
||||
if (!handlers.renderItemValue) {
|
||||
setRenderedValue(null);
|
||||
return;
|
||||
}
|
||||
const rendered = await handlers.renderItemValue(t2iAdapter);
|
||||
setRenderedValue(rendered);
|
||||
};
|
||||
|
||||
_renderValue();
|
||||
}, [handlers, t2iAdapter]);
|
||||
|
||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||
};
|
@ -1,4 +1,5 @@
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { ControlNetConfigV2, IPAdapterConfigV2, T2IAdapterConfigV2 } from 'features/controlLayers/util/controlAdapters';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
/**
|
||||
@ -135,3 +136,11 @@ export type AnyControlAdapterConfigMetadata =
|
||||
| ControlNetConfigMetadata
|
||||
| T2IAdapterConfigMetadata
|
||||
| IPAdapterConfigMetadata;
|
||||
|
||||
export type ControlNetConfigV2Metadata = O.NonNullable<ControlNetConfigV2, 'model'>;
|
||||
export type T2IAdapterConfigV2Metadata = O.NonNullable<T2IAdapterConfigV2, 'model'>;
|
||||
export type IPAdapterConfigV2Metadata = O.NonNullable<IPAdapterConfigV2, 'model'>;
|
||||
export type AnyControlAdapterConfigV2Metadata =
|
||||
| ControlNetConfigV2Metadata
|
||||
| T2IAdapterConfigV2Metadata
|
||||
| IPAdapterConfigV2Metadata;
|
||||
|
@ -3,6 +3,7 @@ import { toast } from 'common/util/toast';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
AnyControlAdapterConfigMetadata,
|
||||
AnyControlAdapterConfigV2Metadata,
|
||||
BuildMetadataHandlers,
|
||||
MetadataGetLabelFunc,
|
||||
MetadataHandlers,
|
||||
@ -43,6 +44,14 @@ const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfig
|
||||
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
const renderControlAdapterValueV2: MetadataRenderValueFunc<AnyControlAdapterConfigV2Metadata> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.model.key ?? 'none');
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||
} catch {
|
||||
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
|
||||
const parameterSetToast = (parameter: string, description?: string) => {
|
||||
toast({
|
||||
@ -341,6 +350,36 @@ export const handlers = {
|
||||
itemValidator: validators.t2iAdapter,
|
||||
renderItemValue: renderControlAdapterValue,
|
||||
}),
|
||||
controlNetsV2: buildHandlers({
|
||||
getLabel: () => t('common.controlNet'),
|
||||
parser: parsers.controlNetsV2,
|
||||
itemParser: parsers.controlNetV2,
|
||||
recaller: recallers.controlNetsV2,
|
||||
itemRecaller: recallers.controlNetV2,
|
||||
validator: validators.controlNetsV2,
|
||||
itemValidator: validators.controlNetV2,
|
||||
renderItemValue: renderControlAdapterValueV2,
|
||||
}),
|
||||
ipAdaptersV2: buildHandlers({
|
||||
getLabel: () => t('common.ipAdapter'),
|
||||
parser: parsers.ipAdaptersV2,
|
||||
itemParser: parsers.ipAdapterV2,
|
||||
recaller: recallers.ipAdaptersV2,
|
||||
itemRecaller: recallers.ipAdapterV2,
|
||||
validator: validators.ipAdaptersV2,
|
||||
itemValidator: validators.ipAdapterV2,
|
||||
renderItemValue: renderControlAdapterValueV2,
|
||||
}),
|
||||
t2iAdaptersV2: buildHandlers({
|
||||
getLabel: () => t('common.t2iAdapter'),
|
||||
parser: parsers.t2iAdaptersV2,
|
||||
itemParser: parsers.t2iAdapterV2,
|
||||
recaller: recallers.t2iAdaptersV2,
|
||||
itemRecaller: recallers.t2iAdapterV2,
|
||||
validator: validators.t2iAdaptersV2,
|
||||
itemValidator: validators.t2iAdapterV2,
|
||||
renderItemValue: renderControlAdapterValueV2,
|
||||
}),
|
||||
} as const;
|
||||
|
||||
export const parseAndRecallPrompts = async (metadata: unknown) => {
|
||||
@ -395,10 +434,25 @@ export const parseAndRecallImageDimensions = async (metadata: unknown) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof typeof handlers)[] = []) => {
|
||||
// These handlers should be omitted when recalling to control layers
|
||||
const TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNets', 'ipAdapters', 't2iAdapters'];
|
||||
// These handlers should be omitted when recalling to the rest of the app
|
||||
const NOT_TO_CONTROL_LAYERS_SKIP_KEYS: (keyof typeof handlers)[] = ['controlNetsV2', 'ipAdaptersV2', 't2iAdaptersV2'];
|
||||
|
||||
export const parseAndRecallAllMetadata = async (
|
||||
metadata: unknown,
|
||||
toControlLayers: boolean,
|
||||
skip: (keyof typeof handlers)[] = []
|
||||
) => {
|
||||
const skipKeys = skip ?? [];
|
||||
if (toControlLayers) {
|
||||
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
} else {
|
||||
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
}
|
||||
const results = await Promise.allSettled(
|
||||
objectKeys(handlers)
|
||||
.filter((key) => !skip.includes(key))
|
||||
.filter((key) => !skipKeys.includes(key))
|
||||
.map((key) => {
|
||||
const { parse, recall } = handlers[key];
|
||||
return parse(metadata).then((value) => {
|
||||
|
@ -5,13 +5,24 @@ import {
|
||||
initialT2IAdapter,
|
||||
} from 'features/controlAdapters/util/buildControlAdapter';
|
||||
import { buildControlAdapterProcessor } from 'features/controlAdapters/util/buildControlAdapterProcessor';
|
||||
import {
|
||||
CA_PROCESSOR_DATA,
|
||||
imageDTOToImageWithDims,
|
||||
initialControlNetV2,
|
||||
initialIPAdapterV2,
|
||||
initialT2IAdapterV2,
|
||||
isProcessorTypeV2,
|
||||
} from 'features/controlLayers/util/controlAdapters';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
ControlNetConfigMetadata,
|
||||
ControlNetConfigV2Metadata,
|
||||
IPAdapterConfigMetadata,
|
||||
IPAdapterConfigV2Metadata,
|
||||
MetadataParseFunc,
|
||||
T2IAdapterConfigMetadata,
|
||||
T2IAdapterConfigV2Metadata,
|
||||
} from 'features/metadata/types';
|
||||
import { fetchModelConfigWithTypeGuard, getModelKey } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { zControlField, zIPAdapterField, zModelIdentifierField, zT2IAdapterField } from 'features/nodes/types/common';
|
||||
@ -58,7 +69,7 @@ import {
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import { get, isArray, isString } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { getImageDTO, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
@ -428,6 +439,203 @@ const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (
|
||||
}
|
||||
};
|
||||
|
||||
//#region V2/Control Layers
|
||||
const parseControlNetV2: MetadataParseFunc<ControlNetConfigV2Metadata> = async (metadataItem) => {
|
||||
const control_model = await getProperty(metadataItem, 'control_model');
|
||||
const key = await getModelKey(control_model, 'controlnet');
|
||||
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||
const image = zControlField.shape.image
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'image'));
|
||||
const processedImage = zControlField.shape.image
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'processed_image'));
|
||||
const control_weight = zControlField.shape.control_weight
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'control_weight'));
|
||||
const begin_step_percent = zControlField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zControlField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||
const control_mode = zControlField.shape.control_mode
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'control_mode'));
|
||||
|
||||
const id = uuidv4();
|
||||
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
|
||||
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
|
||||
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
|
||||
: null;
|
||||
const beginEndStepPct: [number, number] = [
|
||||
begin_step_percent ?? initialControlNetV2.beginEndStepPct[0],
|
||||
end_step_percent ?? initialControlNetV2.beginEndStepPct[1],
|
||||
];
|
||||
const imageDTO = image ? await getImageDTO(image.image_name) : null;
|
||||
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
|
||||
|
||||
const controlNet: ControlNetConfigV2Metadata = {
|
||||
id,
|
||||
type: 'controlnet',
|
||||
model: zModelIdentifierField.parse(controlNetModel),
|
||||
weight: typeof control_weight === 'number' ? control_weight : initialControlNetV2.weight,
|
||||
beginEndStepPct,
|
||||
controlMode: control_mode ?? initialControlNetV2.controlMode,
|
||||
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
|
||||
processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
|
||||
processorConfig,
|
||||
isProcessingImage: false,
|
||||
};
|
||||
|
||||
return controlNet;
|
||||
};
|
||||
|
||||
const parseAllControlNetsV2: MetadataParseFunc<ControlNetConfigV2Metadata[]> = async (metadata) => {
|
||||
try {
|
||||
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray || undefined);
|
||||
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNetV2(cn)));
|
||||
const controlNets = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfigV2Metadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return controlNets;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const parseT2IAdapterV2: MetadataParseFunc<T2IAdapterConfigV2Metadata> = async (metadataItem) => {
|
||||
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
|
||||
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
|
||||
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||
|
||||
const image = zT2IAdapterField.shape.image
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'image'));
|
||||
const processedImage = zT2IAdapterField.shape.image
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'processed_image'));
|
||||
const weight = zT2IAdapterField.shape.weight
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'weight'));
|
||||
const begin_step_percent = zT2IAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zT2IAdapterField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||
|
||||
const id = uuidv4();
|
||||
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
|
||||
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
|
||||
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
|
||||
: null;
|
||||
const beginEndStepPct: [number, number] = [
|
||||
begin_step_percent ?? initialT2IAdapterV2.beginEndStepPct[0],
|
||||
end_step_percent ?? initialT2IAdapterV2.beginEndStepPct[1],
|
||||
];
|
||||
const imageDTO = image ? await getImageDTO(image.image_name) : null;
|
||||
const processedImageDTO = processedImage ? await getImageDTO(processedImage.image_name) : null;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfigV2Metadata = {
|
||||
id,
|
||||
type: 't2i_adapter',
|
||||
model: zModelIdentifierField.parse(t2iAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialT2IAdapterV2.weight,
|
||||
beginEndStepPct,
|
||||
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
|
||||
processedImage: processedImageDTO ? imageDTOToImageWithDims(processedImageDTO) : null,
|
||||
processorConfig,
|
||||
isProcessingImage: false,
|
||||
};
|
||||
|
||||
return t2iAdapter;
|
||||
};
|
||||
|
||||
const parseAllT2IAdaptersV2: MetadataParseFunc<T2IAdapterConfigV2Metadata[]> = async (metadata) => {
|
||||
try {
|
||||
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapterV2(t2iAdapter)));
|
||||
const t2iAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigV2Metadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return t2iAdapters;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
const parseIPAdapterV2: MetadataParseFunc<IPAdapterConfigV2Metadata> = async (metadataItem) => {
|
||||
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
|
||||
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
|
||||
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||
|
||||
const image = zIPAdapterField.shape.image
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'image'));
|
||||
const weight = zIPAdapterField.shape.weight
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'weight'));
|
||||
const method = zIPAdapterField.shape.method
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'method'));
|
||||
const begin_step_percent = zIPAdapterField.shape.begin_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'begin_step_percent'));
|
||||
const end_step_percent = zIPAdapterField.shape.end_step_percent
|
||||
.nullish()
|
||||
.catch(null)
|
||||
.parse(await getProperty(metadataItem, 'end_step_percent'));
|
||||
|
||||
const id = uuidv4();
|
||||
const beginEndStepPct: [number, number] = [
|
||||
begin_step_percent ?? initialIPAdapterV2.beginEndStepPct[0],
|
||||
end_step_percent ?? initialIPAdapterV2.beginEndStepPct[1],
|
||||
];
|
||||
const imageDTO = image ? await getImageDTO(image.image_name) : null;
|
||||
|
||||
const ipAdapter: IPAdapterConfigV2Metadata = {
|
||||
id,
|
||||
type: 'ip_adapter',
|
||||
model: zModelIdentifierField.parse(ipAdapterModel),
|
||||
weight: typeof weight === 'number' ? weight : initialIPAdapterV2.weight,
|
||||
beginEndStepPct,
|
||||
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
|
||||
clipVisionModel: initialIPAdapterV2.clipVisionModel, // TODO: This needs to be added to the zIPAdapterField...
|
||||
method: method ?? initialIPAdapterV2.method,
|
||||
};
|
||||
|
||||
return ipAdapter;
|
||||
};
|
||||
|
||||
const parseAllIPAdaptersV2: MetadataParseFunc<IPAdapterConfigV2Metadata[]> = async (metadata) => {
|
||||
try {
|
||||
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapterV2(ipAdapter)));
|
||||
const ipAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigV2Metadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return ipAdapters;
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
export const parsers = {
|
||||
createdBy: parseCreatedBy,
|
||||
generationMode: parseGenerationMode,
|
||||
@ -464,4 +672,10 @@ export const parsers = {
|
||||
t2iAdapters: parseAllT2IAdapters,
|
||||
ipAdapter: parseIPAdapter,
|
||||
ipAdapters: parseAllIPAdapters,
|
||||
controlNetV2: parseControlNetV2,
|
||||
controlNetsV2: parseAllControlNetsV2,
|
||||
t2iAdapterV2: parseT2IAdapterV2,
|
||||
t2iAdaptersV2: parseAllT2IAdaptersV2,
|
||||
ipAdapterV2: parseIPAdapterV2,
|
||||
ipAdaptersV2: parseAllIPAdaptersV2,
|
||||
} as const;
|
||||
|
@ -6,7 +6,12 @@ import {
|
||||
t2iAdaptersReset,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import {
|
||||
caLayerAdded,
|
||||
caLayerControlNetsDeleted,
|
||||
caLayerT2IAdaptersDeleted,
|
||||
heightChanged,
|
||||
ipaLayerAdded,
|
||||
ipaLayersDeleted,
|
||||
negativePrompt2Changed,
|
||||
negativePromptChanged,
|
||||
positivePrompt2Changed,
|
||||
@ -18,9 +23,12 @@ import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraRecalled, lorasReset } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
ControlNetConfigMetadata,
|
||||
ControlNetConfigV2Metadata,
|
||||
IPAdapterConfigMetadata,
|
||||
IPAdapterConfigV2Metadata,
|
||||
MetadataRecallFunc,
|
||||
T2IAdapterConfigMetadata,
|
||||
T2IAdapterConfigV2Metadata,
|
||||
} from 'features/metadata/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
@ -234,6 +242,52 @@ const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapt
|
||||
});
|
||||
};
|
||||
|
||||
//#region V2/Control Layer
|
||||
const recallControlNetV2: MetadataRecallFunc<ControlNetConfigV2Metadata> = (controlNet) => {
|
||||
getStore().dispatch(caLayerAdded(controlNet));
|
||||
};
|
||||
|
||||
const recallControlNetsV2: MetadataRecallFunc<ControlNetConfigV2Metadata[]> = (controlNets) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(caLayerControlNetsDeleted());
|
||||
if (!controlNets.length) {
|
||||
return;
|
||||
}
|
||||
controlNets.forEach((controlNet) => {
|
||||
dispatch(caLayerAdded(controlNet));
|
||||
});
|
||||
};
|
||||
|
||||
const recallT2IAdapterV2: MetadataRecallFunc<T2IAdapterConfigV2Metadata> = (t2iAdapter) => {
|
||||
getStore().dispatch(caLayerAdded(t2iAdapter));
|
||||
};
|
||||
|
||||
const recallT2IAdaptersV2: MetadataRecallFunc<T2IAdapterConfigV2Metadata[]> = (t2iAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(caLayerT2IAdaptersDeleted());
|
||||
if (!t2iAdapters.length) {
|
||||
return;
|
||||
}
|
||||
t2iAdapters.forEach((t2iAdapters) => {
|
||||
dispatch(caLayerAdded(t2iAdapters));
|
||||
});
|
||||
};
|
||||
|
||||
const recallIPAdapterV2: MetadataRecallFunc<IPAdapterConfigV2Metadata> = (ipAdapter) => {
|
||||
getStore().dispatch(ipaLayerAdded(ipAdapter));
|
||||
};
|
||||
|
||||
const recallIPAdaptersV2: MetadataRecallFunc<IPAdapterConfigV2Metadata[]> = (ipAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
dispatch(ipaLayersDeleted());
|
||||
if (!ipAdapters.length) {
|
||||
return;
|
||||
}
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
dispatch(ipaLayerAdded(ipAdapter));
|
||||
});
|
||||
};
|
||||
|
||||
export const recallers = {
|
||||
positivePrompt: recallPositivePrompt,
|
||||
negativePrompt: recallNegativePrompt,
|
||||
@ -268,4 +322,10 @@ export const recallers = {
|
||||
t2iAdapter: recallT2IAdapter,
|
||||
ipAdapters: recallIPAdapters,
|
||||
ipAdapter: recallIPAdapter,
|
||||
controlNetV2: recallControlNetV2,
|
||||
controlNetsV2: recallControlNetsV2,
|
||||
t2iAdapterV2: recallT2IAdapterV2,
|
||||
t2iAdaptersV2: recallT2IAdaptersV2,
|
||||
ipAdapterV2: recallIPAdapterV2,
|
||||
ipAdaptersV2: recallIPAdaptersV2,
|
||||
} as const;
|
||||
|
@ -2,9 +2,12 @@ import { getStore } from 'app/store/nanostores/store';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
ControlNetConfigMetadata,
|
||||
ControlNetConfigV2Metadata,
|
||||
IPAdapterConfigMetadata,
|
||||
IPAdapterConfigV2Metadata,
|
||||
MetadataValidateFunc,
|
||||
T2IAdapterConfigMetadata,
|
||||
T2IAdapterConfigV2Metadata,
|
||||
} from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
@ -108,6 +111,60 @@ const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipA
|
||||
return new Promise((resolve) => resolve(validatedIPAdapters));
|
||||
};
|
||||
|
||||
const validateControlNetV2: MetadataValidateFunc<ControlNetConfigV2Metadata> = (controlNet) => {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(controlNet));
|
||||
};
|
||||
|
||||
const validateControlNetsV2: MetadataValidateFunc<ControlNetConfigV2Metadata[]> = (controlNets) => {
|
||||
const validatedControlNets: ControlNetConfigV2Metadata[] = [];
|
||||
controlNets.forEach((controlNet) => {
|
||||
try {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
validatedControlNets.push(controlNet);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the ControlNets, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedControlNets));
|
||||
};
|
||||
|
||||
const validateT2IAdapterV2: MetadataValidateFunc<T2IAdapterConfigV2Metadata> = (t2iAdapter) => {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(t2iAdapter));
|
||||
};
|
||||
|
||||
const validateT2IAdaptersV2: MetadataValidateFunc<T2IAdapterConfigV2Metadata[]> = (t2iAdapters) => {
|
||||
const validatedT2IAdapters: T2IAdapterConfigV2Metadata[] = [];
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
validatedT2IAdapters.push(t2iAdapter);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the T2I Adapters, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedT2IAdapters));
|
||||
};
|
||||
|
||||
const validateIPAdapterV2: MetadataValidateFunc<IPAdapterConfigV2Metadata> = (ipAdapter) => {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(ipAdapter));
|
||||
};
|
||||
|
||||
const validateIPAdaptersV2: MetadataValidateFunc<IPAdapterConfigV2Metadata[]> = (ipAdapters) => {
|
||||
const validatedIPAdapters: IPAdapterConfigV2Metadata[] = [];
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
validatedIPAdapters.push(ipAdapter);
|
||||
} catch {
|
||||
// This is a no-op - we want to continue validating the rest of the IP Adapters, and an empty list is valid.
|
||||
}
|
||||
});
|
||||
return new Promise((resolve) => resolve(validatedIPAdapters));
|
||||
};
|
||||
|
||||
export const validators = {
|
||||
refinerModel: validateRefinerModel,
|
||||
vaeModel: validateVAEModel,
|
||||
@ -119,4 +176,10 @@ export const validators = {
|
||||
t2iAdapters: validateT2IAdapters,
|
||||
ipAdapter: validateIPAdapter,
|
||||
ipAdapters: validateIPAdapters,
|
||||
controlNetV2: validateControlNetV2,
|
||||
controlNetsV2: validateControlNetsV2,
|
||||
t2iAdapterV2: validateT2IAdapterV2,
|
||||
t2iAdaptersV2: validateT2IAdaptersV2,
|
||||
ipAdapterV2: validateIPAdapterV2,
|
||||
ipAdaptersV2: validateIPAdaptersV2,
|
||||
} as const;
|
||||
|
@ -43,7 +43,7 @@ export const usePreselectedImage = (selectedImage?: {
|
||||
|
||||
const handleUseAllMetadata = useCallback(() => {
|
||||
if (selectedImageMetadata) {
|
||||
parseAndRecallAllMetadata(selectedImageMetadata);
|
||||
parseAndRecallAllMetadata(selectedImageMetadata, true);
|
||||
}
|
||||
}, [selectedImageMetadata]);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user