mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): tidy model identifier logic
- Move some files around - Use util to extract key and base from model config
This commit is contained in:
parent
3c103c89f3
commit
ab57976e42
@ -6,7 +6,7 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
|
|||||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
import { pick } from 'lodash-es';
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
|||||||
dispatch(
|
dispatch(
|
||||||
controlAdapterModelChanged({
|
controlAdapterModelChanged({
|
||||||
id,
|
id,
|
||||||
model: pick(model, 'base', 'key'),
|
model: getModelKeyAndBase(model),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||||
import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
|
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
export type LoRA = {
|
export type LoRA = {
|
||||||
|
@ -3,16 +3,8 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
|||||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||||
import { modelsApi } from 'services/api/endpoints/models';
|
import { modelsApi } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||||
import {
|
|
||||||
isControlNetModelConfig,
|
|
||||||
isIPAdapterModelConfig,
|
|
||||||
isLoRAModelConfig,
|
|
||||||
isNonRefinerMainModelConfig,
|
|
||||||
isRefinerMainModelModelConfig,
|
|
||||||
isT2IAdapterModelConfig,
|
|
||||||
isTextualInversionModelConfig,
|
|
||||||
isVAEModelConfig,
|
|
||||||
} from 'services/api/types';
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Raised when a model config is unable to be fetched.
|
* Raised when a model config is unable to be fetched.
|
||||||
@ -101,40 +93,6 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
|||||||
return modelConfig;
|
return modelConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(psyche): Remove these helpers once `useRecallParameters` is removed
|
|
||||||
|
|
||||||
export const fetchMainModelConfig = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchRefinerModelConfig = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchVAEModelConfig = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchLoRAModel = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchControlNetModel = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchIPAdapterModel = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchT2IAdapterModel = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const fetchTextualInversionModel = async (key: string) => {
|
|
||||||
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Raises an error if the source base model is incompatible with the target base model.
|
* Raises an error if the source base model is incompatible with the target base model.
|
||||||
* @param sourceBase The source base model.
|
* @param sourceBase The source base model.
|
@ -9,6 +9,11 @@ import type { LoRA } from 'features/lora/store/loraSlice';
|
|||||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||||
import { MetadataParseError } from 'features/metadata/exceptions';
|
import { MetadataParseError } from 'features/metadata/exceptions';
|
||||||
import type { MetadataParseFunc } from 'features/metadata/types';
|
import type { MetadataParseFunc } from 'features/metadata/types';
|
||||||
|
import {
|
||||||
|
fetchModelConfigWithTypeGuard,
|
||||||
|
getModelKey,
|
||||||
|
getModelKeyAndBase,
|
||||||
|
} from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import {
|
import {
|
||||||
zControlField,
|
zControlField,
|
||||||
zIPAdapterField,
|
zIPAdapterField,
|
||||||
@ -54,11 +59,6 @@ import {
|
|||||||
isParameterStrength,
|
isParameterStrength,
|
||||||
isParameterWidth,
|
isParameterWidth,
|
||||||
} from 'features/parameters/types/parameterSchemas';
|
} from 'features/parameters/types/parameterSchemas';
|
||||||
import {
|
|
||||||
fetchModelConfigWithTypeGuard,
|
|
||||||
getModelKey,
|
|
||||||
getModelKeyAndBase,
|
|
||||||
} from 'features/parameters/util/modelFetchingHelpers';
|
|
||||||
import { get, isArray, isString } from 'lodash-es';
|
import { get, isArray, isString } from 'lodash-es';
|
||||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
|
@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
|
|||||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||||
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
|
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { ControlNetModelConfig } from 'services/api/types';
|
import type { ControlNetModelConfig } from 'services/api/types';
|
||||||
@ -36,7 +35,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
|||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||||
@ -36,7 +35,7 @@ const IPAdapterModelFieldInputComponent = (
|
|||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: ipAdapterModels,
|
modelEntities: ipAdapterModels,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
selectedModel: field.value,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
import type { LoRAModelConfig } from 'services/api/types';
|
||||||
@ -35,7 +34,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
|||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||||
@ -37,7 +36,7 @@ const T2IAdapterModelFieldInputComponent = (
|
|||||||
const { options, value, onChange } = useGroupedModelCombobox({
|
const { options, value, onChange } = useGroupedModelCombobox({
|
||||||
modelEntities: t2iAdapterModels,
|
modelEntities: t2iAdapterModels,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
selectedModel: field.value,
|
||||||
});
|
});
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -4,7 +4,6 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
|||||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
import type { VAEModelConfig } from 'services/api/types';
|
import type { VAEModelConfig } from 'services/api/types';
|
||||||
@ -35,7 +34,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
|||||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : null,
|
selectedModel: field.value,
|
||||||
isLoading,
|
isLoading,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -3,8 +3,8 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||||
@ -30,14 +30,14 @@ const ParamVAEModelSelect = () => {
|
|||||||
);
|
);
|
||||||
const _onChange = useCallback(
|
const _onChange = useCallback(
|
||||||
(vae: VAEModelConfig | null) => {
|
(vae: VAEModelConfig | null) => {
|
||||||
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||||
modelEntities: data,
|
modelEntities: data,
|
||||||
onChange: _onChange,
|
onChange: _onChange,
|
||||||
selectedModel: vae ? pick(vae, 'key', 'base') : null,
|
selectedModel: vae,
|
||||||
isLoading,
|
isLoading,
|
||||||
getIsDisabled,
|
getIsDisabled,
|
||||||
});
|
});
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { zParameterControlNetModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import type { ControlNetModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
export const modelIdToControlNetModelParam = (controlNetModelId: string): ControlNetModelField | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, _model_type, model_name] = controlNetModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterControlNetModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
controlNetModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse ControlNet model id'
|
|
||||||
);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { zParameterIPAdapterModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import type { IPAdapterModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
export const modelIdToIPAdapterModelParam = (ipAdapterModelId: string): IPAdapterModelField | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterIPAdapterModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
ipAdapterModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse IP-Adapter model id'
|
|
||||||
);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import { zParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
|
|
||||||
export const modelIdToLoRAModelParam = (loraModelId: string): ParameterLoRAModel | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
|
|
||||||
const [base_model, _model_type, model_name] = loraModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterLoRAModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
loraModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse LoRA model id'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
|
|
||||||
export const modelIdToMainModelParam = (mainModelId: string): ParameterModel | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
model_type,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
mainModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse main model id'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import type { ParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import { zParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
|
|
||||||
export const modelIdToSDXLRefinerModelParam = (mainModelId: string): ParameterSDXLRefinerModel | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterSDXLRefinerModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
model_type,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
mainModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse main model id'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,27 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import { zParameterT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import type { T2IAdapterModelField } from 'services/api/types';
|
|
||||||
|
|
||||||
export const modelIdToT2IAdapterModelParam = (t2iAdapterModelId: string): T2IAdapterModelField | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterT2IAdapterModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
t2iAdapterModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse T2I-Adapter model id'
|
|
||||||
);
|
|
||||||
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -1,26 +0,0 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
|
||||||
import type { ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
import { zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
|
||||||
|
|
||||||
export const modelIdToVAEModelParam = (vaeModelId: string): ParameterVAEModel | undefined => {
|
|
||||||
const log = logger('models');
|
|
||||||
const [base_model, _model_type, model_name] = vaeModelId.split('/');
|
|
||||||
|
|
||||||
const result = zParameterVAEModel.safeParse({
|
|
||||||
base_model,
|
|
||||||
model_name,
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!result.success) {
|
|
||||||
log.error(
|
|
||||||
{
|
|
||||||
vaeModelId,
|
|
||||||
errors: result.error.format(),
|
|
||||||
},
|
|
||||||
'Failed to parse VAE model id'
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.data;
|
|
||||||
};
|
|
@ -2,8 +2,8 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||||
|
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||||
import { pick } from 'lodash-es';
|
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||||
@ -25,7 +25,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
|||||||
dispatch(refinerModelChanged(null));
|
dispatch(refinerModelChanged(null));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
dispatch(refinerModelChanged(pick(model, ['key', 'base'])));
|
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user