tidy(ui): tidy model identifier logic

- Move some files around
- Use util to extract key and base from model config
This commit is contained in:
psychedelicious 2024-02-26 13:40:52 +11:00 committed by Kent Keirsey
parent 3c103c89f3
commit ab57976e42
19 changed files with 21 additions and 256 deletions

View File

@ -6,7 +6,7 @@ import { useControlAdapterModel } from 'features/controlAdapters/hooks/useContro
import { useControlAdapterModelQuery } from 'features/controlAdapters/hooks/useControlAdapterModelQuery';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { pick } from 'lodash-es';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { memo, useCallback, useMemo } from 'react';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
@ -31,7 +31,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
dispatch(
controlAdapterModelChanged({
id,
model: pick(model, 'base', 'key'),
model: getModelKeyAndBase(model),
})
);
},

View File

@ -1,8 +1,8 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { getModelKeyAndBase } from 'features/parameters/util/modelFetchingHelpers';
import type { LoRAModelConfig } from 'services/api/types';
export type LoRA = {

View File

@ -3,16 +3,8 @@ import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isT2IAdapterModelConfig,
isTextualInversionModelConfig,
isVAEModelConfig,
} from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
@ -101,40 +93,6 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
return modelConfig;
};
// TODO(psyche): Remove these helpers once `useRecallParameters` is removed
export const fetchMainModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isNonRefinerMainModelConfig);
};
export const fetchRefinerModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isRefinerMainModelModelConfig);
};
export const fetchVAEModelConfig = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isVAEModelConfig);
};
export const fetchLoRAModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isLoRAModelConfig);
};
export const fetchControlNetModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isControlNetModelConfig);
};
export const fetchIPAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isIPAdapterModelConfig);
};
export const fetchT2IAdapterModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isT2IAdapterModelConfig);
};
export const fetchTextualInversionModel = async (key: string) => {
return fetchModelConfigWithTypeGuard(key, isTextualInversionModelConfig);
};
/**
* Raises an error if the source base model is incompatible with the target base model.
* @param sourceBase The source base model.

View File

@ -9,6 +9,11 @@ import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import { MetadataParseError } from 'features/metadata/exceptions';
import type { MetadataParseFunc } from 'features/metadata/types';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
zIPAdapterField,
@ -54,11 +59,6 @@ import {
isParameterStrength,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/parameters/util/modelFetchingHelpers';
import { get, isArray, isString } from 'lodash-es';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
import {

View File

@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { MetadataValidateFunc } from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
/**

View File

@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { ControlNetModelFieldInputInstance, ControlNetModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig } from 'services/api/types';
@ -36,7 +35,7 @@ const ControlNetModelFieldInputComponent = (props: Props) => {
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
selectedModel: field.value,
isLoading,
});

View File

@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { IPAdapterModelFieldInputInstance, IPAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
import type { IPAdapterModelConfig } from 'services/api/types';
@ -36,7 +35,7 @@ const IPAdapterModelFieldInputComponent = (
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: ipAdapterModels,
onChange: _onChange,
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
selectedModel: field.value,
});
return (

View File

@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { LoRAModelFieldInputInstance, LoRAModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
@ -35,7 +34,7 @@ const LoRAModelFieldInputComponent = (props: Props) => {
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
selectedModel: field.value,
isLoading,
});

View File

@ -3,7 +3,6 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { T2IAdapterModelFieldInputInstance, T2IAdapterModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
import type { T2IAdapterModelConfig } from 'services/api/types';
@ -37,7 +36,7 @@ const T2IAdapterModelFieldInputComponent = (
const { options, value, onChange } = useGroupedModelCombobox({
modelEntities: t2iAdapterModels,
onChange: _onChange,
selectedModel: field.value ? pick(field.value, ['key', 'base']) : undefined,
selectedModel: field.value,
});
return (

View File

@ -4,7 +4,6 @@ import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
import type { VAEModelConfig } from 'services/api/types';
@ -35,7 +34,7 @@ const VAEModelFieldInputComponent = (props: Props) => {
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: field.value ? pick(field.value, ['key', 'base']) : null,
selectedModel: field.value,
isLoading,
});

View File

@ -3,8 +3,8 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { selectGenerationSlice, vaeSelected } from 'features/parameters/store/generationSlice';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetVaeModelsQuery } from 'services/api/endpoints/models';
@ -30,14 +30,14 @@ const ParamVAEModelSelect = () => {
);
const _onChange = useCallback(
(vae: VAEModelConfig | null) => {
dispatch(vaeSelected(vae ? pick(vae, 'key', 'base') : null));
dispatch(vaeSelected(vae ? getModelKeyAndBase(vae) : null));
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
modelEntities: data,
onChange: _onChange,
selectedModel: vae ? pick(vae, 'key', 'base') : null,
selectedModel: vae,
isLoading,
getIsDisabled,
});

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import { zParameterControlNetModel } from 'features/parameters/types/parameterSchemas';
import type { ControlNetModelField } from 'services/api/types';
export const modelIdToControlNetModelParam = (controlNetModelId: string): ControlNetModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = controlNetModelId.split('/');
const result = zParameterControlNetModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
controlNetModelId,
errors: result.error.format(),
},
'Failed to parse ControlNet model id'
);
return;
}
return result.data;
};

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import { zParameterIPAdapterModel } from 'features/parameters/types/parameterSchemas';
import type { IPAdapterModelField } from 'services/api/types';
export const modelIdToIPAdapterModelParam = (ipAdapterModelId: string): IPAdapterModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
const result = zParameterIPAdapterModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
ipAdapterModelId,
errors: result.error.format(),
},
'Failed to parse IP-Adapter model id'
);
return;
}
return result.data;
};

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import type { ParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
import { zParameterLoRAModel } from 'features/parameters/types/parameterSchemas';
export const modelIdToLoRAModelParam = (loraModelId: string): ParameterLoRAModel | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = loraModelId.split('/');
const result = zParameterLoRAModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
loraModelId,
errors: result.error.format(),
},
'Failed to parse LoRA model id'
);
return;
}
return result.data;
};

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
export const modelIdToMainModelParam = (mainModelId: string): ParameterModel | undefined => {
const log = logger('models');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zParameterModel.safeParse({
base_model,
model_name,
model_type,
});
if (!result.success) {
log.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}
return result.data;
};

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import type { ParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
import { zParameterSDXLRefinerModel } from 'features/parameters/types/parameterSchemas';
export const modelIdToSDXLRefinerModelParam = (mainModelId: string): ParameterSDXLRefinerModel | undefined => {
const log = logger('models');
const [base_model, model_type, model_name] = mainModelId.split('/');
const result = zParameterSDXLRefinerModel.safeParse({
base_model,
model_name,
model_type,
});
if (!result.success) {
log.error(
{
mainModelId,
errors: result.error.format(),
},
'Failed to parse main model id'
);
return;
}
return result.data;
};

View File

@ -1,27 +0,0 @@
import { logger } from 'app/logging/logger';
import { zParameterT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
import type { T2IAdapterModelField } from 'services/api/types';
export const modelIdToT2IAdapterModelParam = (t2iAdapterModelId: string): T2IAdapterModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
const result = zParameterT2IAdapterModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
t2iAdapterModelId,
errors: result.error.format(),
},
'Failed to parse T2I-Adapter model id'
);
return;
}
return result.data;
};

View File

@ -1,26 +0,0 @@
import { logger } from 'app/logging/logger';
import type { ParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
export const modelIdToVAEModelParam = (vaeModelId: string): ParameterVAEModel | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = vaeModelId.split('/');
const result = zParameterVAEModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
vaeModelId,
errors: result.error.format(),
},
'Failed to parse VAE model id'
);
return;
}
return result.data;
};

View File

@ -2,8 +2,8 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { getModelKeyAndBase } from 'features/metadata/util/modelFetchingHelpers';
import { refinerModelChanged, selectSdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { pick } from 'lodash-es';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { REFINER_BASE_MODELS } from 'services/api/constants';
@ -25,7 +25,7 @@ const ParamSDXLRefinerModelSelect = () => {
dispatch(refinerModelChanged(null));
return;
}
dispatch(refinerModelChanged(pick(model, ['key', 'base'])));
dispatch(refinerModelChanged(getModelKeyAndBase(model)));
},
[dispatch]
);