tidy(ui): tidy model identifier logic

- Move some files around
- Use util to extract key and base from model config
This commit is contained in:
psychedelicious
2024-02-26 13:40:52 +11:00
parent 7176c5d9d6
commit a253047d8e
19 changed files with 21 additions and 256 deletions

View File

@ -0,0 +1,131 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierWithBase } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
/**
* Raised when a model config is unable to be fetched.
*/
export class ModelConfigNotFoundError extends Error {
/**
* Create ModelConfigNotFoundError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Raised when a fetched model config is of an unexpected type.
*/
export class InvalidModelConfigError extends Error {
/**
* Create InvalidModelConfigError
* @param {String} message
*/
constructor(message: string) {
super(message);
this.name = this.constructor.name;
}
}
/**
* Fetches the model config for a given model key.
* @param key The model key.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
}
};
/**
* Fetches the model config for a given model name, base model, and model type. This provides backwards compatibility
* for MM1 model identifiers.
* @param name The model name.
* @param base The base model.
* @param type The model type.
* @returns A promise that resolves to the model config.
* @throws {ModelConfigNotFoundError} If the model config is unable to be fetched.
*/
export const fetchModelConfigByAttrs = async (
name: string,
base: BaseModelType,
type: ModelType
): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
req.unsubscribe();
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
}
};
/**
* Fetches the model config for a given model key and type, and ensures that the model config is of a specific type.
* @param key The model key.
* @param typeGuard A type guard function that checks if the model config is of the expected type.
* @returns A promise that resolves to the model config. The model config is guaranteed to be of the expected type.
* @throws {InvalidModelConfigError} If the model config is unable to be fetched or is of an unexpected type.
*/
export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
key: string,
typeGuard: (config: AnyModelConfig) => config is T
) => {
const modelConfig = await fetchModelConfig(key);
if (!typeGuard(modelConfig)) {
throw new InvalidModelConfigError(`Invalid model type for key ${key}: ${modelConfig.type}`);
}
return modelConfig;
};
/**
* Raises an error if the source base model is incompatible with the target base model.
* @param sourceBase The source base model.
* @param targetBase The target base model.
* @param message An optional custom message to include in the error.
* @throws {InvalidModelConfigError} If the source base model is incompatible with the target base model.
*/
export const raiseIfBaseIncompatible = (sourceBase: BaseModelType, targetBase?: BaseModelType, message?: string) => {
if (targetBase && sourceBase !== targetBase) {
throw new InvalidModelConfigError(message || `Incompatible base models: ${sourceBase} and ${targetBase}`);
}
};
/**
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid.
*/
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) {
return modelIdentifier.key;
}
if (isModelIdentifierV2(modelIdentifier)) {
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
}
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};
export const getModelKeyAndBase = (modelConfig: AnyModelConfig): ModelIdentifierWithBase => ({
key: modelConfig.key,
base: modelConfig.base,
});

View File

@ -9,6 +9,11 @@ import type { LoRA } from 'features/lora/store/loraSlice';
import { defaultLoRAConfig } from 'features/lora/store/loraSlice';
import { MetadataParseError } from 'features/metadata/exceptions';
import type { MetadataParseFunc } from 'features/metadata/types';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/metadata/util/modelFetchingHelpers';
import {
zControlField,
zIPAdapterField,
@ -54,11 +59,6 @@ import {
isParameterStrength,
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import {
fetchModelConfigWithTypeGuard,
getModelKey,
getModelKeyAndBase,
} from 'features/parameters/util/modelFetchingHelpers';
import { get, isArray, isString } from 'lodash-es';
import type { NonRefinerMainModelConfig, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
import {

View File

@ -2,7 +2,7 @@ import { getStore } from 'app/store/nanostores/store';
import type { ControlNetConfig, IPAdapterConfig, T2IAdapterConfig } from 'features/controlAdapters/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type { MetadataValidateFunc } from 'features/metadata/types';
import { InvalidModelConfigError } from 'features/parameters/util/modelFetchingHelpers';
import { InvalidModelConfigError } from 'features/metadata/util/modelFetchingHelpers';
import type { BaseModelType, RefinerMainModelConfig, VAEModelConfig } from 'services/api/types';
/**