mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): create metadata types for control adapters
These are the same as the existing control adapter types, but the model field is non-nullable, simplifying handling of these objects.
This commit is contained in:
parent
a3b11c04cb
commit
3efd9465eb
@ -1,6 +1,5 @@
|
||||
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import type { ControlNetConfigMetadata, MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
@ -9,7 +8,7 @@ type Props = {
|
||||
};
|
||||
|
||||
export const MetadataControlNets = ({ metadata }: Props) => {
|
||||
const [controlNets, setControlNets] = useState<ControlNetConfig[]>([]);
|
||||
const [controlNets, setControlNets] = useState<ControlNetConfigMetadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
@ -45,8 +44,8 @@ const MetadataViewControlNet = ({
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
controlNet: ControlNetConfig;
|
||||
handlers: MetadataHandlers<ControlNetConfig[], ControlNetConfig>;
|
||||
controlNet: ControlNetConfigMetadata;
|
||||
handlers: MetadataHandlers<ControlNetConfigMetadata[], ControlNetConfigMetadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
|
@ -1,6 +1,5 @@
|
||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import type { IPAdapterConfigMetadata, MetadataHandlers } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
@ -9,7 +8,7 @@ type Props = {
|
||||
};
|
||||
|
||||
export const MetadataIPAdapters = ({ metadata }: Props) => {
|
||||
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfig[]>([]);
|
||||
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfigMetadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
@ -40,8 +39,8 @@ const MetadataViewIPAdapter = ({
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
ipAdapter: IPAdapterConfig;
|
||||
handlers: MetadataHandlers<IPAdapterConfig[], IPAdapterConfig>;
|
||||
ipAdapter: IPAdapterConfigMetadata;
|
||||
handlers: MetadataHandlers<IPAdapterConfigMetadata[], IPAdapterConfigMetadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
|
@ -1,6 +1,5 @@
|
||||
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||
import type { MetadataHandlers } from 'features/metadata/types';
|
||||
import type { MetadataHandlers, T2IAdapterConfigMetadata } from 'features/metadata/types';
|
||||
import { handlers } from 'features/metadata/util/handlers';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
|
||||
@ -9,7 +8,7 @@ type Props = {
|
||||
};
|
||||
|
||||
export const MetadataT2IAdapters = ({ metadata }: Props) => {
|
||||
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfig[]>([]);
|
||||
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfigMetadata[]>([]);
|
||||
|
||||
useEffect(() => {
|
||||
const parse = async () => {
|
||||
@ -45,8 +44,8 @@ const MetadataViewT2IAdapter = ({
|
||||
handlers,
|
||||
}: {
|
||||
label: string;
|
||||
t2iAdapter: T2IAdapterConfig;
|
||||
handlers: MetadataHandlers<T2IAdapterConfig[], T2IAdapterConfig>;
|
||||
t2iAdapter: T2IAdapterConfigMetadata;
|
||||
handlers: MetadataHandlers<T2IAdapterConfigMetadata[], T2IAdapterConfigMetadata>;
|
||||
}) => {
|
||||
const onRecall = useCallback(() => {
|
||||
if (!handlers.recallItem) {
|
||||
|
@ -1,3 +1,6 @@
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
/**
|
||||
* Renders a value of type T as a React node.
|
||||
*/
|
||||
@ -134,3 +137,11 @@ export type BuildMetadataHandlersArg<TValue, TItem> = {
|
||||
export type BuildMetadataHandlers = <TValue, TItem>(
|
||||
arg: BuildMetadataHandlersArg<TValue, TItem>
|
||||
) => MetadataHandlers<TValue, TItem>;
|
||||
|
||||
export type ControlNetConfigMetadata = O.NonNullable<ControlNetConfig, 'model'>;
|
||||
export type T2IAdapterConfigMetadata = O.NonNullable<T2IAdapterConfig, 'model'>;
|
||||
export type IPAdapterConfigMetadata = O.NonNullable<IPAdapterConfig, 'model'>;
|
||||
export type AnyControlAdapterConfigMetadata =
|
||||
| ControlNetConfigMetadata
|
||||
| T2IAdapterConfigMetadata
|
||||
| IPAdapterConfigMetadata;
|
||||
|
@ -1,8 +1,8 @@
|
||||
import { objectKeys } from 'common/util/objectKeys';
|
||||
import { toast } from 'common/util/toast';
|
||||
import type { ControlAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
AnyControlAdapterConfigMetadata,
|
||||
BuildMetadataHandlers,
|
||||
MetadataGetLabelFunc,
|
||||
MetadataHandlers,
|
||||
@ -35,12 +35,12 @@ const renderLoRAValue: MetadataRenderValueFunc<LoRA> = async (value) => {
|
||||
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = async (value) => {
|
||||
const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfigMetadata> = async (value) => {
|
||||
try {
|
||||
const modelConfig = await fetchModelConfig(value.model?.key ?? 'none');
|
||||
const modelConfig = await fetchModelConfig(value.model.key ?? 'none');
|
||||
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||
} catch {
|
||||
return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`;
|
||||
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import {
|
||||
initialControlNet,
|
||||
initialIPAdapter,
|
||||
@ -8,7 +7,7 @@ import {
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||
import { MetadataParseError } from 'features/metadata/exceptions';
|
||||
import type { MetadataParseFunc } from 'features/metadata/types';
|
||||
import type { ControlNetConfigMetadata, IPAdapterConfigMetadata, MetadataParseFunc, T2IAdapterConfigMetadata } from 'features/metadata/types';
|
||||
import {
|
||||
fetchModelConfigWithTypeGuard,
|
||||
getModelKey,
|
||||
@ -213,7 +212,7 @@ const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
|
||||
return loras;
|
||||
};
|
||||
|
||||
const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem) => {
|
||||
const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (metadataItem) => {
|
||||
const control_model = await getProperty(metadataItem, 'control_model');
|
||||
const key = await getModelKey(control_model, 'controlnet');
|
||||
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||
@ -243,7 +242,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const controlNet: ControlNetConfig = {
|
||||
const controlNet: ControlNetConfigMetadata = {
|
||||
type: 'controlnet',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(controlNetModel),
|
||||
@ -263,16 +262,16 @@ const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem
|
||||
return controlNet;
|
||||
};
|
||||
|
||||
const parseAllControlNets: MetadataParseFunc<ControlNetConfig[]> = async (metadata) => {
|
||||
const parseAllControlNets: MetadataParseFunc<ControlNetConfigMetadata[]> = async (metadata) => {
|
||||
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
|
||||
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
|
||||
const controlNets = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfig> => result.status === 'fulfilled')
|
||||
.filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return controlNets;
|
||||
};
|
||||
|
||||
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem) => {
|
||||
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (metadataItem) => {
|
||||
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
|
||||
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
|
||||
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||
@ -295,7 +294,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem
|
||||
const processorType = 'none';
|
||||
const processorNode = CONTROLNET_PROCESSORS.none.default;
|
||||
|
||||
const t2iAdapter: T2IAdapterConfig = {
|
||||
const t2iAdapter: T2IAdapterConfigMetadata = {
|
||||
type: 't2i_adapter',
|
||||
isEnabled: true,
|
||||
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
|
||||
@ -314,16 +313,16 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem
|
||||
return t2iAdapter;
|
||||
};
|
||||
|
||||
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfig[]> = async (metadata) => {
|
||||
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfigMetadata[]> = async (metadata) => {
|
||||
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
|
||||
const t2iAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfig> => result.status === 'fulfilled')
|
||||
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return t2iAdapters;
|
||||
};
|
||||
|
||||
const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem) => {
|
||||
const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metadataItem) => {
|
||||
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
|
||||
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
|
||||
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||
@ -339,7 +338,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem)
|
||||
.catch(null)
|
||||
.parse(getProperty(metadataItem, 'end_step_percent'));
|
||||
|
||||
const ipAdapter: IPAdapterConfig = {
|
||||
const ipAdapter: IPAdapterConfigMetadata = {
|
||||
id: uuidv4(),
|
||||
type: 'ip_adapter',
|
||||
isEnabled: true,
|
||||
@ -353,11 +352,11 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem)
|
||||
return ipAdapter;
|
||||
};
|
||||
|
||||
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfig[]> = async (metadata) => {
|
||||
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (metadata) => {
|
||||
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
|
||||
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
|
||||
const ipAdapters = parseResults
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfig> => result.status === 'fulfilled')
|
||||
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
|
||||
.map((result) => result.value);
|
||||
return ipAdapters;
|
||||
};
|
||||
|
@ -1,10 +1,14 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { loraRecalled } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataRecallFunc } from 'features/metadata/types';
|
||||
import type {
|
||||
ControlNetConfigMetadata,
|
||||
IPAdapterConfigMetadata,
|
||||
MetadataRecallFunc,
|
||||
T2IAdapterConfigMetadata,
|
||||
} from 'features/metadata/types';
|
||||
import { modelSelected } from 'features/parameters/store/actions';
|
||||
import {
|
||||
heightRecalled,
|
||||
@ -168,33 +172,33 @@ const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
|
||||
});
|
||||
};
|
||||
|
||||
const recallControlNet: MetadataRecallFunc<ControlNetConfig> = (controlNet) => {
|
||||
const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlNet) => {
|
||||
getStore().dispatch(controlAdapterRecalled(controlNet));
|
||||
};
|
||||
|
||||
const recallControlNets: MetadataRecallFunc<ControlNetConfig[]> = (controlNets) => {
|
||||
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
|
||||
const { dispatch } = getStore();
|
||||
controlNets.forEach((controlNet) => {
|
||||
dispatch(controlAdapterRecalled(controlNet));
|
||||
});
|
||||
};
|
||||
|
||||
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfig> = (t2iAdapter) => {
|
||||
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapter) => {
|
||||
getStore().dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
};
|
||||
|
||||
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
|
||||
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
dispatch(controlAdapterRecalled(t2iAdapter));
|
||||
});
|
||||
};
|
||||
|
||||
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfig> = (ipAdapter) => {
|
||||
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter) => {
|
||||
getStore().dispatch(controlAdapterRecalled(ipAdapter));
|
||||
};
|
||||
|
||||
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfig[]> = (ipAdapters) => {
|
||||
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
|
||||
const { dispatch } = getStore();
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
dispatch(controlAdapterRecalled(ipAdapter));
|
||||
|
@ -1,7 +1,11 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||
import type {
|
||||
ControlNetConfigMetadata,
|
||||
IPAdapterConfigMetadata,
|
||||
MetadataValidateFunc,
|
||||
T2IAdapterConfigMetadata,
|
||||
} from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { BaseModelType } from 'services/api/types';
|
||||
@ -50,13 +54,13 @@ const validateLoRAs: MetadataValidateFunc<LoRA[]> = (loras) => {
|
||||
return new Promise((resolve) => resolve(validatedLoRAs));
|
||||
};
|
||||
|
||||
const validateControlNet: MetadataValidateFunc<ControlNetConfig> = (controlNet) => {
|
||||
const validateControlNet: MetadataValidateFunc<ControlNetConfigMetadata> = (controlNet) => {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(controlNet));
|
||||
};
|
||||
|
||||
const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNets) => {
|
||||
const validatedControlNets: ControlNetConfig[] = [];
|
||||
const validateControlNets: MetadataValidateFunc<ControlNetConfigMetadata[]> = (controlNets) => {
|
||||
const validatedControlNets: ControlNetConfigMetadata[] = [];
|
||||
controlNets.forEach((controlNet) => {
|
||||
try {
|
||||
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
|
||||
@ -68,13 +72,13 @@ const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNe
|
||||
return new Promise((resolve) => resolve(validatedControlNets));
|
||||
};
|
||||
|
||||
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfig> = (t2iAdapter) => {
|
||||
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfigMetadata> = (t2iAdapter) => {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(t2iAdapter));
|
||||
};
|
||||
|
||||
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
|
||||
const validatedT2IAdapters: T2IAdapterConfig[] = [];
|
||||
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
|
||||
const validatedT2IAdapters: T2IAdapterConfigMetadata[] = [];
|
||||
t2iAdapters.forEach((t2iAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
|
||||
@ -86,13 +90,13 @@ const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapte
|
||||
return new Promise((resolve) => resolve(validatedT2IAdapters));
|
||||
};
|
||||
|
||||
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfig> = (ipAdapter) => {
|
||||
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfigMetadata> = (ipAdapter) => {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
return new Promise((resolve) => resolve(ipAdapter));
|
||||
};
|
||||
|
||||
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfig[]> = (ipAdapters) => {
|
||||
const validatedIPAdapters: IPAdapterConfig[] = [];
|
||||
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
|
||||
const validatedIPAdapters: IPAdapterConfigMetadata[] = [];
|
||||
ipAdapters.forEach((ipAdapter) => {
|
||||
try {
|
||||
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
|
||||
|
Loading…
Reference in New Issue
Block a user