mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): model metadata handlers use model identifiers, not configs
Model metadata includes the main model, VAE and refiner model. These used full model configs, as returned by the server, as their metadata type. LoRA and control adapter metadata only use the metadata identifier. This created a difference in handling. After parsing a model/vae/refiner, we have its name and can display it. But for LoRAs and control adapters, we only have the model key and must query for the full model config to get the name. This change makes main model/vae/refiner metadata only have the model key, like LoRAs and control adapters. The render function is now async so fetching can occur within it. All metadata fields with models now only contain the identifier, and fetch the model name to render their values.
This commit is contained in:
parent
7b4ef5926d
commit
9abfb02bf0
@ -1,4 +1,3 @@
|
|||||||
import { Text } from '@invoke-ai/ui-library';
|
|
||||||
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
|
import type { ControlNetConfig } from 'features/controlAdapters/store/types';
|
||||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||||
import type { MetadataHandlers } from 'features/metadata/types';
|
import type { MetadataHandlers } from 'features/metadata/types';
|
||||||
@ -56,11 +55,18 @@ const MetadataViewControlNet = ({
|
|||||||
handlers.recallItem(controlNet, true);
|
handlers.recallItem(controlNet, true);
|
||||||
}, [handlers, controlNet]);
|
}, [handlers, controlNet]);
|
||||||
|
|
||||||
const renderedValue = useMemo(() => {
|
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const _renderValue = async () => {
|
||||||
if (!handlers.renderItemValue) {
|
if (!handlers.renderItemValue) {
|
||||||
return null;
|
setRenderedValue(null);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return <Text>{handlers.renderItemValue(controlNet)}</Text>;
|
const rendered = await handlers.renderItemValue(controlNet);
|
||||||
|
setRenderedValue(rendered);
|
||||||
|
};
|
||||||
|
|
||||||
|
_renderValue();
|
||||||
}, [handlers, controlNet]);
|
}, [handlers, controlNet]);
|
||||||
|
|
||||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { Text } from '@invoke-ai/ui-library';
|
|
||||||
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
import type { IPAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||||
import type { MetadataHandlers } from 'features/metadata/types';
|
import type { MetadataHandlers } from 'features/metadata/types';
|
||||||
@ -51,11 +50,18 @@ const MetadataViewIPAdapter = ({
|
|||||||
handlers.recallItem(ipAdapter, true);
|
handlers.recallItem(ipAdapter, true);
|
||||||
}, [handlers, ipAdapter]);
|
}, [handlers, ipAdapter]);
|
||||||
|
|
||||||
const renderedValue = useMemo(() => {
|
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const _renderValue = async () => {
|
||||||
if (!handlers.renderItemValue) {
|
if (!handlers.renderItemValue) {
|
||||||
return null;
|
setRenderedValue(null);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return <Text>{handlers.renderItemValue(ipAdapter)}</Text>;
|
const rendered = await handlers.renderItemValue(ipAdapter);
|
||||||
|
setRenderedValue(rendered);
|
||||||
|
};
|
||||||
|
|
||||||
|
_renderValue();
|
||||||
}, [handlers, ipAdapter]);
|
}, [handlers, ipAdapter]);
|
||||||
|
|
||||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { Text } from '@invoke-ai/ui-library';
|
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||||
import type { MetadataHandlers } from 'features/metadata/types';
|
import type { MetadataHandlers } from 'features/metadata/types';
|
||||||
@ -51,11 +50,18 @@ const MetadataViewLoRA = ({
|
|||||||
handlers.recallItem(lora, true);
|
handlers.recallItem(lora, true);
|
||||||
}, [handlers, lora]);
|
}, [handlers, lora]);
|
||||||
|
|
||||||
const renderedValue = useMemo(() => {
|
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const _renderValue = async () => {
|
||||||
if (!handlers.renderItemValue) {
|
if (!handlers.renderItemValue) {
|
||||||
return null;
|
setRenderedValue(null);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return <Text>{handlers.renderItemValue(lora)}</Text>;
|
const rendered = await handlers.renderItemValue(lora);
|
||||||
|
setRenderedValue(rendered);
|
||||||
|
};
|
||||||
|
|
||||||
|
_renderValue();
|
||||||
}, [handlers, lora]);
|
}, [handlers, lora]);
|
||||||
|
|
||||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { Text } from '@invoke-ai/ui-library';
|
|
||||||
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
import type { T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
import { MetadataItemView } from 'features/metadata/components/MetadataItemView';
|
||||||
import type { MetadataHandlers } from 'features/metadata/types';
|
import type { MetadataHandlers } from 'features/metadata/types';
|
||||||
@ -56,11 +55,18 @@ const MetadataViewT2IAdapter = ({
|
|||||||
handlers.recallItem(t2iAdapter, true);
|
handlers.recallItem(t2iAdapter, true);
|
||||||
}, [handlers, t2iAdapter]);
|
}, [handlers, t2iAdapter]);
|
||||||
|
|
||||||
const renderedValue = useMemo(() => {
|
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const _renderValue = async () => {
|
||||||
if (!handlers.renderItemValue) {
|
if (!handlers.renderItemValue) {
|
||||||
return null;
|
setRenderedValue(null);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return <Text>{handlers.renderItemValue(t2iAdapter)}</Text>;
|
const rendered = await handlers.renderItemValue(t2iAdapter);
|
||||||
|
setRenderedValue(rendered);
|
||||||
|
};
|
||||||
|
|
||||||
|
_renderValue();
|
||||||
}, [handlers, t2iAdapter]);
|
}, [handlers, t2iAdapter]);
|
||||||
|
|
||||||
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
return <MetadataItemView label={label} isDisabled={false} onRecall={onRecall} renderedValue={renderedValue} />;
|
||||||
|
@ -3,10 +3,14 @@ import type { MetadataHandlers } from 'features/metadata/types';
|
|||||||
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
|
import { MetadataParseFailedToken, MetadataParsePendingToken } from 'features/metadata/util/parsers';
|
||||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||||
|
|
||||||
|
const pendingRenderedValue = <Text>Loading</Text>;
|
||||||
|
const failedRenderedValue = <Text>Parsing Failed</Text>;
|
||||||
|
|
||||||
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
|
export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandlers<T>) => {
|
||||||
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
|
const [value, setValue] = useState<T | typeof MetadataParsePendingToken | typeof MetadataParseFailedToken>(
|
||||||
MetadataParsePendingToken
|
MetadataParsePendingToken
|
||||||
);
|
);
|
||||||
|
const [renderedValue, setRenderedValue] = useState<React.ReactNode>(pendingRenderedValue);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const _parse = async () => {
|
const _parse = async () => {
|
||||||
@ -24,20 +28,27 @@ export const useMetadataItem = <T,>(metadata: unknown, handlers: MetadataHandler
|
|||||||
|
|
||||||
const label = useMemo(() => handlers.getLabel(), [handlers]);
|
const label = useMemo(() => handlers.getLabel(), [handlers]);
|
||||||
|
|
||||||
const renderedValue = useMemo(() => {
|
useEffect(() => {
|
||||||
|
const _renderValue = async () => {
|
||||||
if (value === MetadataParsePendingToken) {
|
if (value === MetadataParsePendingToken) {
|
||||||
return <Text>Loading</Text>;
|
setRenderedValue(pendingRenderedValue);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
if (value === MetadataParseFailedToken) {
|
if (value === MetadataParseFailedToken) {
|
||||||
return <Text>Parsing Failed</Text>;
|
setRenderedValue(failedRenderedValue);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const rendered = handlers.renderValue(value);
|
const rendered = await handlers.renderValue(value);
|
||||||
|
|
||||||
if (typeof rendered === 'string') {
|
if (typeof rendered === 'string') {
|
||||||
return <Text>{rendered}</Text>;
|
setRenderedValue(<Text>{rendered}</Text>);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
return rendered;
|
setRenderedValue(rendered);
|
||||||
|
};
|
||||||
|
|
||||||
|
_renderValue();
|
||||||
}, [handlers, value]);
|
}, [handlers, value]);
|
||||||
|
|
||||||
const onRecall = useCallback(() => {
|
const onRecall = useCallback(() => {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Renders a value of type T as a React node.
|
* Renders a value of type T as a React node.
|
||||||
*/
|
*/
|
||||||
export type MetadataRenderValueFunc<T> = (value: T) => React.ReactNode;
|
export type MetadataRenderValueFunc<T> = (value: T) => Promise<React.ReactNode>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets the label of the current metadata item as a string.
|
* Gets the label of the current metadata item as a string.
|
||||||
|
@ -11,18 +11,38 @@ import type {
|
|||||||
MetadataRenderValueFunc,
|
MetadataRenderValueFunc,
|
||||||
MetadataValidateFunc,
|
MetadataValidateFunc,
|
||||||
} from 'features/metadata/types';
|
} from 'features/metadata/types';
|
||||||
|
import { fetchModelConfig } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { validators } from 'features/metadata/util/validators';
|
import { validators } from 'features/metadata/util/validators';
|
||||||
|
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import { parsers } from './parsers';
|
import { parsers } from './parsers';
|
||||||
import { recallers } from './recallers';
|
import { recallers } from './recallers';
|
||||||
|
|
||||||
const renderModelConfigValue: MetadataRenderValueFunc<AnyModelConfig> = (value) =>
|
const renderModelConfigValue: MetadataRenderValueFunc<ModelIdentifierWithBase> = async (value) => {
|
||||||
`${value.name} (${value.base.toUpperCase()}, ${value.key})`;
|
try {
|
||||||
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = (value) => `${value.model.key} (${value.weight})`;
|
const modelConfig = await fetchModelConfig(value.key);
|
||||||
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = (value) =>
|
return `${modelConfig.name} (${modelConfig.base.toUpperCase()})`;
|
||||||
`${value.model?.key} (${value.weight})`;
|
} catch {
|
||||||
|
return `${value.key} (${value.base.toUpperCase()})`;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const renderLoRAValue: MetadataRenderValueFunc<LoRA> = async (value) => {
|
||||||
|
try {
|
||||||
|
const modelConfig = await fetchModelConfig(value.model.key);
|
||||||
|
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||||
|
} catch {
|
||||||
|
return `${value.model.key} (${value.model.base.toUpperCase()}) - ${value.weight}`;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
const renderControlAdapterValue: MetadataRenderValueFunc<ControlAdapterConfig> = async (value) => {
|
||||||
|
try {
|
||||||
|
const modelConfig = await fetchModelConfig(value.model?.key ?? 'none');
|
||||||
|
return `${modelConfig.name} (${modelConfig.base.toUpperCase()}) - ${value.weight}`;
|
||||||
|
} catch {
|
||||||
|
return `${value.model?.key} (${value.model?.base.toUpperCase()}) - ${value.weight}`;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const parameterSetToast = (parameter: string, description?: string) => {
|
const parameterSetToast = (parameter: string, description?: string) => {
|
||||||
toast({
|
toast({
|
||||||
@ -130,6 +150,8 @@ const buildRecallItem =
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const resolveToString = (value: unknown) => new Promise<React.ReactNode>((resolve) => resolve(String(value)));
|
||||||
|
|
||||||
const buildHandlers: BuildMetadataHandlers = ({
|
const buildHandlers: BuildMetadataHandlers = ({
|
||||||
getLabel,
|
getLabel,
|
||||||
parser,
|
parser,
|
||||||
@ -146,8 +168,8 @@ const buildHandlers: BuildMetadataHandlers = ({
|
|||||||
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
|
recall: recaller ? buildRecall({ recaller, validator, getLabel }) : undefined,
|
||||||
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
|
recallItem: itemRecaller ? buildRecallItem({ itemRecaller, itemValidator, getLabel }) : undefined,
|
||||||
getLabel,
|
getLabel,
|
||||||
renderValue: renderValue ?? String,
|
renderValue: renderValue ?? resolveToString,
|
||||||
renderItemValue: renderItemValue ?? String,
|
renderItemValue: renderItemValue ?? resolveToString,
|
||||||
});
|
});
|
||||||
|
|
||||||
export const handlers = {
|
export const handlers = {
|
||||||
|
@ -12,7 +12,6 @@ import type { MetadataParseFunc } from 'features/metadata/types';
|
|||||||
import {
|
import {
|
||||||
fetchModelConfigWithTypeGuard,
|
fetchModelConfigWithTypeGuard,
|
||||||
getModelKey,
|
getModelKey,
|
||||||
getModelKeyAndBase,
|
|
||||||
} from 'features/metadata/util/modelFetchingHelpers';
|
} from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import {
|
||||||
zControlField,
|
zControlField,
|
||||||
@ -26,17 +25,20 @@ import type {
|
|||||||
ParameterHeight,
|
ParameterHeight,
|
||||||
ParameterHRFEnabled,
|
ParameterHRFEnabled,
|
||||||
ParameterHRFMethod,
|
ParameterHRFMethod,
|
||||||
|
ParameterModel,
|
||||||
ParameterNegativePrompt,
|
ParameterNegativePrompt,
|
||||||
ParameterNegativeStylePromptSDXL,
|
ParameterNegativeStylePromptSDXL,
|
||||||
ParameterPositivePrompt,
|
ParameterPositivePrompt,
|
||||||
ParameterPositiveStylePromptSDXL,
|
ParameterPositiveStylePromptSDXL,
|
||||||
ParameterScheduler,
|
ParameterScheduler,
|
||||||
|
ParameterSDXLRefinerModel,
|
||||||
ParameterSDXLRefinerNegativeAestheticScore,
|
ParameterSDXLRefinerNegativeAestheticScore,
|
||||||
ParameterSDXLRefinerPositiveAestheticScore,
|
ParameterSDXLRefinerPositiveAestheticScore,
|
||||||
ParameterSDXLRefinerStart,
|
ParameterSDXLRefinerStart,
|
||||||
ParameterSeed,
|
ParameterSeed,
|
||||||
ParameterSteps,
|
ParameterSteps,
|
||||||
ParameterStrength,
|
ParameterStrength,
|
||||||
|
ParameterVAEModel,
|
||||||
ParameterWidth,
|
ParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import {
|
import {
|
||||||
@ -60,7 +62,6 @@ import {
|
|||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import { get, isArray, isString } from 'lodash-es';
|
import { get, isArray, isString } from 'lodash-es';
|
||||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
|
||||||
import {
|
import {
|
||||||
isControlNetModelConfig,
|
isControlNetModelConfig,
|
||||||
isIPAdapterModelConfig,
|
isIPAdapterModelConfig,
|
||||||
@ -163,25 +164,28 @@ const parseRefinerNegativeAestheticScore: MetadataParseFunc<ParameterSDXLRefiner
|
|||||||
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
|
const parseRefinerStart: MetadataParseFunc<ParameterSDXLRefinerStart> = (metadata) =>
|
||||||
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
|
getProperty(metadata, 'refiner_start', isParameterSDXLRefinerStart);
|
||||||
|
|
||||||
const parseMainModel: MetadataParseFunc<NonRefinerMainModelConfig> = async (metadata) => {
|
const parseMainModel: MetadataParseFunc<ParameterModel> = async (metadata) => {
|
||||||
const model = await getProperty(metadata, 'model', undefined);
|
const model = await getProperty(metadata, 'model', undefined);
|
||||||
const key = await getModelKey(model, 'main');
|
const key = await getModelKey(model, 'main');
|
||||||
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
const mainModelConfig = await fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||||
return mainModelConfig;
|
const modelIdentifier = zModelIdentifierWithBase.parse(mainModelConfig);
|
||||||
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
const parseRefinerModel: MetadataParseFunc<RefinerMainModelConfig> = async (metadata) => {
|
const parseRefinerModel: MetadataParseFunc<ParameterSDXLRefinerModel> = async (metadata) => {
|
||||||
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
const refiner_model = await getProperty(metadata, 'refiner_model', undefined);
|
||||||
const key = await getModelKey(refiner_model, 'main');
|
const key = await getModelKey(refiner_model, 'main');
|
||||||
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
const refinerModelConfig = await fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||||
return refinerModelConfig;
|
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModelConfig);
|
||||||
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
const parseVAEModel: MetadataParseFunc<VAEModelConfig> = async (metadata) => {
|
const parseVAEModel: MetadataParseFunc<ParameterVAEModel> = async (metadata) => {
|
||||||
const vae = await getProperty(metadata, 'vae', undefined);
|
const vae = await getProperty(metadata, 'vae', undefined);
|
||||||
const key = await getModelKey(vae, 'vae');
|
const key = await getModelKey(vae, 'vae');
|
||||||
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
const vaeModelConfig = await fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||||
return vaeModelConfig;
|
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModelConfig);
|
||||||
|
return modelIdentifier;
|
||||||
};
|
};
|
||||||
|
|
||||||
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
||||||
@ -194,7 +198,7 @@ const parseLoRA: MetadataParseFunc<LoRA> = async (metadataItem) => {
|
|||||||
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
const loraModelConfig = await fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
model: getModelKeyAndBase(loraModelConfig),
|
model: zModelIdentifierWithBase.parse(loraModelConfig),
|
||||||
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
weight: isParameterLoRAWeight(weight) ? weight : defaultLoRAConfig.weight,
|
||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
};
|
};
|
||||||
|
@ -5,7 +5,6 @@ import { setHrfEnabled, setHrfMethod, setHrfStrength } from 'features/hrf/store/
|
|||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import { loraRecalled } from 'features/lora/store/loraSlice';
|
import { loraRecalled } from 'features/lora/store/loraSlice';
|
||||||
import type { MetadataRecallFunc } from 'features/metadata/types';
|
import type { MetadataRecallFunc } from 'features/metadata/types';
|
||||||
import { zModelIdentifierWithBase } from 'features/nodes/types/common';
|
|
||||||
import { modelSelected } from 'features/parameters/store/actions';
|
import { modelSelected } from 'features/parameters/store/actions';
|
||||||
import {
|
import {
|
||||||
heightRecalled,
|
heightRecalled,
|
||||||
@ -26,17 +25,20 @@ import type {
|
|||||||
ParameterHeight,
|
ParameterHeight,
|
||||||
ParameterHRFEnabled,
|
ParameterHRFEnabled,
|
||||||
ParameterHRFMethod,
|
ParameterHRFMethod,
|
||||||
|
ParameterModel,
|
||||||
ParameterNegativePrompt,
|
ParameterNegativePrompt,
|
||||||
ParameterNegativeStylePromptSDXL,
|
ParameterNegativeStylePromptSDXL,
|
||||||
ParameterPositivePrompt,
|
ParameterPositivePrompt,
|
||||||
ParameterPositiveStylePromptSDXL,
|
ParameterPositiveStylePromptSDXL,
|
||||||
ParameterScheduler,
|
ParameterScheduler,
|
||||||
|
ParameterSDXLRefinerModel,
|
||||||
ParameterSDXLRefinerNegativeAestheticScore,
|
ParameterSDXLRefinerNegativeAestheticScore,
|
||||||
ParameterSDXLRefinerPositiveAestheticScore,
|
ParameterSDXLRefinerPositiveAestheticScore,
|
||||||
ParameterSDXLRefinerStart,
|
ParameterSDXLRefinerStart,
|
||||||
ParameterSeed,
|
ParameterSeed,
|
||||||
ParameterSteps,
|
ParameterSteps,
|
||||||
ParameterStrength,
|
ParameterStrength,
|
||||||
|
ParameterVAEModel,
|
||||||
ParameterWidth,
|
ParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import {
|
import {
|
||||||
@ -50,7 +52,6 @@ import {
|
|||||||
setRefinerStart,
|
setRefinerStart,
|
||||||
setRefinerSteps,
|
setRefinerSteps,
|
||||||
} from 'features/sdxl/store/sdxlSlice';
|
} from 'features/sdxl/store/sdxlSlice';
|
||||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
|
||||||
getStore().dispatch(setPositivePrompt(positivePrompt));
|
getStore().dispatch(setPositivePrompt(positivePrompt));
|
||||||
@ -140,23 +141,20 @@ const recallRefinerStart: MetadataRecallFunc<ParameterSDXLRefinerStart> = (refin
|
|||||||
getStore().dispatch(setRefinerStart(refinerStart));
|
getStore().dispatch(setRefinerStart(refinerStart));
|
||||||
};
|
};
|
||||||
|
|
||||||
const recallModel: MetadataRecallFunc<NonRefinerMainModelConfig> = (model) => {
|
const recallModel: MetadataRecallFunc<ParameterModel> = (model) => {
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(model);
|
getStore().dispatch(modelSelected(model));
|
||||||
getStore().dispatch(modelSelected(modelIdentifier));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const recallRefinerModel: MetadataRecallFunc<RefinerMainModelConfig> = (refinerModel) => {
|
const recallRefinerModel: MetadataRecallFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(refinerModel);
|
getStore().dispatch(refinerModelChanged(refinerModel));
|
||||||
getStore().dispatch(refinerModelChanged(modelIdentifier));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const recallVAE: MetadataRecallFunc<VAEModelConfig | null | undefined> = (vaeModel) => {
|
const recallVAE: MetadataRecallFunc<ParameterVAEModel | null | undefined> = (vaeModel) => {
|
||||||
if (!vaeModel) {
|
if (!vaeModel) {
|
||||||
getStore().dispatch(vaeSelected(null));
|
getStore().dispatch(vaeSelected(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const modelIdentifier = zModelIdentifierWithBase.parse(vaeModel);
|
getStore().dispatch(vaeSelected(vaeModel));
|
||||||
getStore().dispatch(vaeSelected(modelIdentifier));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
const recallLoRA: MetadataRecallFunc<LoRA> = (lora) => {
|
||||||
|
@ -3,7 +3,8 @@ import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'featur
|
|||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
import type { ParameterSDXLRefinerModel, ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import type { BaseModelType } from 'services/api/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
|
* Checks the given base model type against the currently-selected model's base type and throws an error if they are
|
||||||
@ -21,12 +22,12 @@ const validateBaseCompatibility = (base?: BaseModelType, message?: string) => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const validateRefinerModel: MetadataValidateFunc<RefinerMainModelConfig> = (refinerModel) => {
|
const validateRefinerModel: MetadataValidateFunc<ParameterSDXLRefinerModel> = (refinerModel) => {
|
||||||
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
|
validateBaseCompatibility('sdxl', 'Refiner incompatible with currently-selected model');
|
||||||
return new Promise((resolve) => resolve(refinerModel));
|
return new Promise((resolve) => resolve(refinerModel));
|
||||||
};
|
};
|
||||||
|
|
||||||
const validateVAEModel: MetadataValidateFunc<VAEModelConfig> = (vaeModel) => {
|
const validateVAEModel: MetadataValidateFunc<ParameterVAEModel> = (vaeModel) => {
|
||||||
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
|
validateBaseCompatibility(vaeModel.base, 'VAE incompatible with currently-selected model');
|
||||||
return new Promise((resolve) => resolve(vaeModel));
|
return new Promise((resolve) => resolve(vaeModel));
|
||||||
};
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user