mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(ui): tidy model identifier logic
- Move some files around - Use util to extract key and base from model config
This commit is contained in:
parent
3c103c89f3
commit
ab57976e42
@ -6,7 +6,7 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
|
||||
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
|
||||
|
||||
@ -31,7 +31,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
dispatch(
|
||||
controlAdapterModelChanged({
|
||||
id,
|
||||
model: pick(model, 'base', 'key'),
|
||||
model: getModelKeyAndBase(model),
|
||||
})
|
||||
);
|
||||
},
|
||||
|
@ -1,8 +1,8 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
export type LoRA = {
|
||||
|
@ -3,16 +3,8 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||
import {
|
||||
isControlNetModelConfig,
|
||||
isIPAdapterModelConfig,
|
||||
isLoRAModelConfig,
|
||||
isNonRefinerMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTextualInversionModelConfig,
|
||||
isVAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* Raised when a model config is unable to be fetched.
|
||||
@ -101,40 +93,6 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||
return modelConfig;
|
||||
};
|
||||
|
||||
// TODO(psyche): Remove these helpers once `useRecallParameters` is removed
|
||||
|
||||
export const fetchMainModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
|
||||
};
|
||||
|
||||
export const fetchRefinerModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
|
||||
};
|
||||
|
||||
export const fetchVAEModelConfig = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
|
||||
};
|
||||
|
||||
export const fetchLoRAModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
|
||||
};
|
||||
|
||||
export const fetchControlNetModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
|
||||
};
|
||||
|
||||
export const fetchIPAdapterModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
|
||||
};
|
||||
|
||||
export const fetchT2IAdapterModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
|
||||
};
|
||||
|
||||
export const fetchTextualInversionModel = async (key: string) => {
|
||||
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
|
||||
};
|
||||
|
||||
/**
|
||||
* Raises an error if the source base model is incompatible with the target base model.
|
||||
* @param sourceBase The source base model.
|
@ -9,6 +9,11 @@ import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
|
||||
import { MetadataParseError } from 'features/metadata/exceptions';
|
||||
import type { MetadataParseFunc } from 'features/metadata/types';
|
||||
import {
|
||||
fetchModelConfigWithTypeGuard,
|
||||
getModelKey,
|
||||
getModelKeyAndBase,
|
||||
} from 'features/metadata/util/modelFetchingHelpers';
|
||||
import {
|
||||
zControlField,
|
||||
zIPAdapterField,
|
||||
@ -54,11 +59,6 @@ import {
|
||||
isParameterStrength,
|
||||
isParameterWidth,
|
||||
} from 'features/parameters/types/parameterSchemas';
|
||||
import {
|
||||
fetchModelConfigWithTypeGuard,
|
||||
getModelKey,
|
||||
getModelKeyAndBase,
|
||||
} from 'features/parameters/util/modelFetchingHelpers';
|
||||
import { get, isArray, isString } from 'lodash-es';
|
||||
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
import {
|
||||
|
@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type { MetadataValidateFunc } from 'features/metadata/types';
|
||||
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
|
||||
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
|
||||
|
||||
/**
|
||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { ControlNetModelConfig } from 'services/api/types';
|
||||
@ -36,7 +35,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { IPAdapterModelConfig } from 'services/api/types';
|
||||
@ -36,7 +35,7 @@ const IPAdapterModelFieldInputComponent = (
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: ipAdapterModels,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
@ -35,7 +34,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
|
@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { T2IAdapterModelConfig } from 'services/api/types';
|
||||
@ -37,7 +36,7 @@ const T2IAdapterModelFieldInputComponent = (
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelEntities: t2iAdapterModels,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
|
@ -4,7 +4,6 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { VAEModelConfig } from 'services/api/types';
|
||||
@ -35,7 +34,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value ? pick(field.value, ['key', 'base']) : null,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
|
@ -3,8 +3,8 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
|
||||
@ -30,14 +30,14 @@ const ParamVAEModelSelect = () => {
|
||||
);
|
||||
const _onChange = useCallback(
|
||||
(vae: VAEModelConfig | null) => {
|
||||
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
|
||||
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelEntities: data,
|
||||
onChange: _onChange,
|
||||
selectedModel: vae ? pick(vae, 'key', 'base') : null,
|
||||
selectedModel: vae,
|
||||
isLoading,
|
||||
getIsDisabled,
|
||||
});
|
||||
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { zParameterControlNetModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { ControlNetModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToControlNetModelParam = (controlNetModelId: string): ControlNetModelField | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = controlNetModelId.split('/');
|
||||
|
||||
const result = zParameterControlNetModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
controlNetModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse ControlNet model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { zParameterIPAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { IPAdapterModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToIPAdapterModelParam = (ipAdapterModelId: string): IPAdapterModelField | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
|
||||
|
||||
const result = zParameterIPAdapterModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
ipAdapterModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse IP-Adapter model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToLoRAModelParam = (loraModelId: string): ParameterLoRAModel | undefined => {
|
||||
const log = logger('models');
|
||||
|
||||
const [base_model, _model_type, model_name] = loraModelId.split('/');
|
||||
|
||||
const result = zParameterLoRAModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
loraModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse LoRA model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToMainModelParam = (mainModelId: string): ParameterModel | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zParameterModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { ParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToSDXLRefinerModelParam = (mainModelId: string): ParameterSDXLRefinerModel | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, model_type, model_name] = mainModelId.split('/');
|
||||
|
||||
const result = zParameterSDXLRefinerModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
model_type,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
mainModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse main model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,27 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { zParameterT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import type { T2IAdapterModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToT2IAdapterModelParam = (t2iAdapterModelId: string): T2IAdapterModelField | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
|
||||
|
||||
const result = zParameterT2IAdapterModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
t2iAdapterModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse T2I-Adapter model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -1,26 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export const modelIdToVAEModelParam = (vaeModelId: string): ParameterVAEModel | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = vaeModelId.split('/');
|
||||
|
||||
const result = zParameterVAEModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
vaeModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse VAE model id'
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
@ -2,8 +2,8 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useModelCombobox } from 'common/hooks/useModelCombobox';
|
||||
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
|
||||
import { pick } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { REFINER_BASE_MODELS } from 'services/api/constants';
|
||||
@ -25,7 +25,7 @@ const ParamSDXLRefinerModelSelect = () => {
|
||||
dispatch(refinerModelChanged(null));
|
||||
return;
|
||||
}
|
||||
dispatch(refinerModelChanged(pick(model, ['key', 'base'])));
|
||||
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user