feat(ui): create metadata types for control adapters

These are the same as the existing control adapter types, but the model field is non-nullable, simplifying handling of these objects.
This commit is contained in:
psychedelicious 2024-02-26 23:21:53 +11:00
parent 9abfb02bf0
commit d23f2de9d7
8 changed files with 67 additions and 52 deletions

View File

@ -1,6 +1,5 @@
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import type { ControlNetConfigMetadata, MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
@ -9,7 +8,7 @@ type Props = {
};
export const MetadataControlNets = ({ metadata }: Props) => {
const [controlNets, setControlNets] = useState<ControlNetConfig[]>([]);
const [controlNets, setControlNets] = useState<ControlNetConfigMetadata[]>([]);
useEffect(() => {
const parse = async () => {
@ -45,8 +44,8 @@ const MetadataViewControlNet = ({
handlers,
}: {
label: string;
controlNet: ControlNetConfig;
handlers: MetadataHandlers<ControlNetConfig[], ControlNetConfig>;
controlNet: ControlNetConfigMetadata;
handlers: MetadataHandlers<ControlNetConfigMetadata[], ControlNetConfigMetadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {

View File

@ -1,6 +1,5 @@
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import type { IPAdapterConfigMetadata, MetadataHandlers } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
@ -9,7 +8,7 @@ type Props = {
};
export const MetadataIPAdapters = ({ metadata }: Props) => {
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfig[]>([]);
const [ipAdapters, setIPAdapters] = useState<IPAdapterConfigMetadata[]>([]);
useEffect(() => {
const parse = async () => {
@ -40,8 +39,8 @@ const MetadataViewIPAdapter = ({
handlers,
}: {
label: string;
ipAdapter: IPAdapterConfig;
handlers: MetadataHandlers<IPAdapterConfig[], IPAdapterConfig>;
ipAdapter: IPAdapterConfigMetadata;
handlers: MetadataHandlers<IPAdapterConfigMetadata[], IPAdapterConfigMetadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {

View File

@ -1,6 +1,5 @@
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
import type { MetadataHandlers } from 'features/metadata/types';
import type { MetadataHandlers, T2IAdapterConfigMetadata } from 'features/metadata/types';
import { handlers } from 'features/metadata/util/handlers';
import { useCallback, useEffect, useMemo, useState } from 'react';
@ -9,7 +8,7 @@ type Props = {
};
export const MetadataT2IAdapters = ({ metadata }: Props) => {
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfig[]>([]);
const [t2iAdapters, setT2IAdapters] = useState<T2IAdapterConfigMetadata[]>([]);
useEffect(() => {
const parse = async () => {
@ -45,8 +44,8 @@ const MetadataViewT2IAdapter = ({
handlers,
}: {
label: string;
t2iAdapter: T2IAdapterConfig;
handlers: MetadataHandlers<T2IAdapterConfig[], T2IAdapterConfig>;
t2iAdapter: T2IAdapterConfigMetadata;
handlers: MetadataHandlers<T2IAdapterConfigMetadata[], T2IAdapterConfigMetadata>;
}) => {
const onRecall = useCallback(() => {
if (!handlers.recallItem) {

View File

@ -1,3 +1,6 @@
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { O } from 'ts-toolbelt';
/**
* Renders a value of type T as a React node.
*/
@ -134,3 +137,11 @@ export type BuildMetadataHandlersArg<TValue, TItem> = {
export type BuildMetadataHandlers = <TValue, TItem>(
arg: BuildMetadataHandlersArg<TValue, TItem>
) => MetadataHandlers<TValue, TItem>;
export type ControlNetConfigMetadata = O.NonNullable<ControlNetConfig, 'model'>;
export type T2IAdapterConfigMetadata = O.NonNullable<T2IAdapterConfig, 'model'>;
export type IPAdapterConfigMetadata = O.NonNullable<IPAdapterConfig, 'model'>;
export type AnyControlAdapterConfigMetadata =
| ControlNetConfigMetadata
| T2IAdapterConfigMetadata
| IPAdapterConfigMetadata;

View File

@ -1,8 +1,8 @@
import { objectKeys } from 'common/util/objectKeys';
import { toast } from 'common/util/toast';
import type { ControlAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
AnyControlAdapterConfigMetadata,
BuildMetadataHandlers,
MetadataGetLabelFunc,
MetadataHandlers,
@ -35,12 +35,12 @@ const renderLoRAValue: MetadataRenderValueFunc<LoRA> = async (value) => {
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = async (value) => {
const renderControlAdapterValue: MetadataRenderValueFunc<AnyControlAdapterConfigMetadata> = async (value) => {
try {
const modelConfig = await fetchModelConfig(value.model?.key ?? 'none');
const modelConfig = await fetchModelConfig(value.model.key ?? 'none');
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
} catch {
return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`;
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
}
};

View File

@ -1,5 +1,4 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import {
initialControlNet,
initialIPAdapter,
@ -8,7 +7,7 @@ import {
import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import { MetadataParseError } from 'features/metadata/exceptions';
import type { MetadataParseFunc } from 'features/metadata/types';
import type { ControlNetConfigMetadata, IPAdapterConfigMetadata, MetadataParseFunc, T2IAdapterConfigMetadata } from 'features/metadata/types';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
@ -213,7 +212,7 @@ const parseAllLoRAs: MetadataParseFunc<LoRA[]> = async (metadata) => {
return loras;
};
const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem) => {
const parseControlNet: MetadataParseFunc<ControlNetConfigMetadata> = async (metadataItem) => {
const control_model = await getProperty(metadataItem, 'control_model');
const key = await getModelKey(control_model, 'controlnet');
const controlNetModel = await fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
@ -243,7 +242,7 @@ const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const controlNet: ControlNetConfig = {
const controlNet: ControlNetConfigMetadata = {
type: 'controlnet',
isEnabled: true,
model: zModelIdentifierWithBase.parse(controlNetModel),
@ -263,16 +262,16 @@ const parseControlNet: MetadataParseFunc<ControlNetConfig> = async (metadataItem
return controlNet;
};
const parseAllControlNets: MetadataParseFunc<ControlNetConfig[]> = async (metadata) => {
const parseAllControlNets: MetadataParseFunc<ControlNetConfigMetadata[]> = async (metadata) => {
const controlNetsRaw = await getProperty(metadata, 'controlnets', isArray);
const parseResults = await Promise.allSettled(controlNetsRaw.map((cn) => parseControlNet(cn)));
const controlNets = parseResults
.filter((result): result is PromiseFulfilledResult<ControlNetConfig> => result.status === 'fulfilled')
.filter((result): result is PromiseFulfilledResult<ControlNetConfigMetadata> => result.status === 'fulfilled')
.map((result) => result.value);
return controlNets;
};
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem) => {
const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfigMetadata> = async (metadataItem) => {
const t2i_adapter_model = await getProperty(metadataItem, 't2i_adapter_model');
const key = await getModelKey(t2i_adapter_model, 't2i_adapter');
const t2iAdapterModel = await fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
@ -295,7 +294,7 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem
const processorType = 'none';
const processorNode = CONTROLNET_PROCESSORS.none.default;
const t2iAdapter: T2IAdapterConfig = {
const t2iAdapter: T2IAdapterConfigMetadata = {
type: 't2i_adapter',
isEnabled: true,
model: zModelIdentifierWithBase.parse(t2iAdapterModel),
@ -314,16 +313,16 @@ const parseT2IAdapter: MetadataParseFunc<T2IAdapterConfig> = async (metadataItem
return t2iAdapter;
};
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfig[]> = async (metadata) => {
const parseAllT2IAdapters: MetadataParseFunc<T2IAdapterConfigMetadata[]> = async (metadata) => {
const t2iAdaptersRaw = await getProperty(metadata, 't2iAdapters', isArray);
const parseResults = await Promise.allSettled(t2iAdaptersRaw.map((t2iAdapter) => parseT2IAdapter(t2iAdapter)));
const t2iAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfig> => result.status === 'fulfilled')
.filter((result): result is PromiseFulfilledResult<T2IAdapterConfigMetadata> => result.status === 'fulfilled')
.map((result) => result.value);
return t2iAdapters;
};
const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem) => {
const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metadataItem) => {
const ip_adapter_model = await getProperty(metadataItem, 'ip_adapter_model');
const key = await getModelKey(ip_adapter_model, 'ip_adapter');
const ipAdapterModel = await fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
@ -339,7 +338,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem)
.catch(null)
.parse(getProperty(metadataItem, 'end_step_percent'));
const ipAdapter: IPAdapterConfig = {
const ipAdapter: IPAdapterConfigMetadata = {
id: uuidv4(),
type: 'ip_adapter',
isEnabled: true,
@ -353,11 +352,11 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfig> = async (metadataItem)
return ipAdapter;
};
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfig[]> = async (metadata) => {
const parseAllIPAdapters: MetadataParseFunc<IPAdapterConfigMetadata[]> = async (metadata) => {
const ipAdaptersRaw = await getProperty(metadata, 'ipAdapters', isArray);
const parseResults = await Promise.allSettled(ipAdaptersRaw.map((ipAdapter) => parseIPAdapter(ipAdapter)));
const ipAdapters = parseResults
.filter((result): result is PromiseFulfilledResult<IPAdapterConfig> => result.status === 'fulfilled')
.filter((result): result is PromiseFulfilledResult<IPAdapterConfigMetadata> => result.status === 'fulfilled')
.map((result) => result.value);
return ipAdapters;
};

View File

@ -1,10 +1,14 @@
import { getStore } from 'app/store/nanostores/store';
import { controlAdapterRecalled } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/hrfSlice';
import type { LoRA } from 'features/lora/store/loraSlice';
import { loraRecalled } from 'features/lora/store/loraSlice';
import type { MetadataRecallFunc } from 'features/metadata/types';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
MetadataRecallFunc,
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { modelSelected } from 'features/parameters/store/actions';
import {
heightRecalled,
@ -168,33 +172,33 @@ const recallAllLoRAs: MetadataRecallFunc<LoRA[]> = (loras) => {
});
};
const recallControlNet: MetadataRecallFunc<ControlNetConfig> = (controlNet) => {
const recallControlNet: MetadataRecallFunc<ControlNetConfigMetadata> = (controlNet) => {
getStore().dispatch(controlAdapterRecalled(controlNet));
};
const recallControlNets: MetadataRecallFunc<ControlNetConfig[]> = (controlNets) => {
const recallControlNets: MetadataRecallFunc<ControlNetConfigMetadata[]> = (controlNets) => {
const { dispatch } = getStore();
controlNets.forEach((controlNet) => {
dispatch(controlAdapterRecalled(controlNet));
});
};
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfig> = (t2iAdapter) => {
const recallT2IAdapter: MetadataRecallFunc<T2IAdapterConfigMetadata> = (t2iAdapter) => {
getStore().dispatch(controlAdapterRecalled(t2iAdapter));
};
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
const recallT2IAdapters: MetadataRecallFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
const { dispatch } = getStore();
t2iAdapters.forEach((t2iAdapter) => {
dispatch(controlAdapterRecalled(t2iAdapter));
});
};
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfig> = (ipAdapter) => {
const recallIPAdapter: MetadataRecallFunc<IPAdapterConfigMetadata> = (ipAdapter) => {
getStore().dispatch(controlAdapterRecalled(ipAdapter));
};
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfig[]> = (ipAdapters) => {
const recallIPAdapters: MetadataRecallFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
const { dispatch } = getStore();
ipAdapters.forEach((ipAdapter) => {
dispatch(controlAdapterRecalled(ipAdapter));

View File

@ -1,7 +1,11 @@
import { getStore } from 'app/store/nanostores/store';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { MetadataValidateFunc } from 'features/metadata/types';
import type {
ControlNetConfigMetadata,
IPAdapterConfigMetadata,
MetadataValidateFunc,
T2IAdapterConfigMetadata,
} from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import type { BaseModelType } from 'services/api/types';
@ -50,13 +54,13 @@ const validateLoRAs: MetadataValidateFunc<LoRA[]> = (loras) => {
return new Promise((resolve) => resolve(validatedLoRAs));
};
const validateControlNet: MetadataValidateFunc<ControlNetConfig> = (controlNet) => {
const validateControlNet: MetadataValidateFunc<ControlNetConfigMetadata> = (controlNet) => {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
return new Promise((resolve) => resolve(controlNet));
};
const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNets) => {
const validatedControlNets: ControlNetConfig[] = [];
const validateControlNets: MetadataValidateFunc<ControlNetConfigMetadata[]> = (controlNets) => {
const validatedControlNets: ControlNetConfigMetadata[] = [];
controlNets.forEach((controlNet) => {
try {
validateBaseCompatibility(controlNet.model?.base, 'ControlNet incompatible with currently-selected model');
@ -68,13 +72,13 @@ const validateControlNets: MetadataValidateFunc<ControlNetConfig[]> = (controlNe
return new Promise((resolve) => resolve(validatedControlNets));
};
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfig> = (t2iAdapter) => {
const validateT2IAdapter: MetadataValidateFunc<T2IAdapterConfigMetadata> = (t2iAdapter) => {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(t2iAdapter));
};
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapters) => {
const validatedT2IAdapters: T2IAdapterConfig[] = [];
const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfigMetadata[]> = (t2iAdapters) => {
const validatedT2IAdapters: T2IAdapterConfigMetadata[] = [];
t2iAdapters.forEach((t2iAdapter) => {
try {
validateBaseCompatibility(t2iAdapter.model?.base, 'T2I Adapter incompatible with currently-selected model');
@ -86,13 +90,13 @@ const validateT2IAdapters: MetadataValidateFunc<T2IAdapterConfig[]> = (t2iAdapte
return new Promise((resolve) => resolve(validatedT2IAdapters));
};
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfig> = (ipAdapter) => {
const validateIPAdapter: MetadataValidateFunc<IPAdapterConfigMetadata> = (ipAdapter) => {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');
return new Promise((resolve) => resolve(ipAdapter));
};
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfig[]> = (ipAdapters) => {
const validatedIPAdapters: IPAdapterConfig[] = [];
const validateIPAdapters: MetadataValidateFunc<IPAdapterConfigMetadata[]> = (ipAdapters) => {
const validatedIPAdapters: IPAdapterConfigMetadata[] = [];
ipAdapters.forEach((ipAdapter) => {
try {
validateBaseCompatibility(ipAdapter.model?.base, 'IP Adapter incompatible with currently-selected model');