fix(ui): fix remaining TS issues

This commit is contained in:
psychedelicious 2024-02-27 15:31:40 +11:00
parent 97ecd99b9c
commit aaeef03593
8 changed files with 56 additions and 93 deletions

View File

@ -148,11 +148,11 @@ export const AdvancedImport = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>

View File

@ -14,7 +14,7 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(

View File

@ -1,13 +1,12 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
@ -21,10 +20,10 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
value: format,
label: LORA_MODEL_FORMAT_MAP[format],
})) as ComboboxOption[];
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },

View File

@ -15,7 +15,7 @@ const options: ComboboxOption[] = [
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
];
] as const
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);

View File

@ -14,22 +14,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react';
import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type {
AnyModelConfig,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
@ -48,38 +39,38 @@ export const ModelEdit = () => {
const { t } = useTranslation();
const modelData = useMemo(() => {
if (!data) {
return null;
}
const modelFormat = data.format;
const modelType = data.type;
// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;
if (modelType === 'main') {
if (modelFormat === 'diffusers') {
return data as DiffusersModelConfig;
} else if (modelFormat === 'checkpoint') {
return data as CheckpointModelConfig;
}
}
// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }
switch (modelType) {
case 'lora':
return data as LoRAModelConfig;
case 'embedding':
return data as TextualInversionModelConfig;
case 't2i_adapter':
return data as T2IAdapterModelConfig;
case 'ip_adapter':
return data as IPAdapterModelConfig;
case 'controlnet':
return data as ControlNetModelConfig;
case 'vae':
return data as VAEModelConfig;
default:
return null;
}
}, [data]);
// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);
const {
register,
@ -88,9 +79,9 @@ export const ModelEdit = () => {
formState: { errors },
reset,
watch,
} = useForm<AnyModelConfig>({
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...modelData,
...data,
},
mode: 'onChange',
});
@ -100,19 +91,19 @@ export const ModelEdit = () => {
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
if (!modelData?.key) {
if (!data?.key) {
return;
}
const responseBody = {
key: modelData.key,
const responseBody: UpdateModelArg = {
key: data.key,
body: values,
};
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as AnyModelConfig, { keepDefaultValues: true });
reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
@ -135,7 +126,7 @@ export const ModelEdit = () => {
);
});
},
[dispatch, modelData?.key, reset, t, updateModel]
[dispatch, data?.key, reset, t, updateModel]
);
const handleClickCancel = useCallback(() => {
@ -146,7 +137,7 @@ export const ModelEdit = () => {
return <Text>{t('common.loading')}</Text>;
}
if (!modelData) {
if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>;
}
return (
@ -193,7 +184,7 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
@ -203,7 +194,7 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>

View File

@ -1,5 +1,4 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { LoRAModelFormat } from 'services/api/types';
/**
* Mapping of base model to human readable name
@ -49,16 +48,6 @@ export const CLIP_SKIP_MAP = {
},
};
/**
* Mapping of LoRA format to human readable name
*/
export const LORA_MODEL_FORMAT_MAP: {
[key in LoRAModelFormat]: string;
} = {
lycoris: 'LyCORIS',
diffusers: 'Diffusers',
};
/**
* Mapping of schedulers to human readable name
*/

View File

@ -19,7 +19,7 @@ import type {
import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
type UpdateModelArg = {
export type UpdateModelArg = {
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};

View File

@ -50,7 +50,6 @@ export type BaseModelType = S['BaseModelType'];
export type MainModelField = S['MainModelField'];
export type VAEModelField = S['VAEModelField'];
export type LoRAModelField = S['LoRAModelField'];
export type LoRAModelFormat = S['LoRAModelFormat'];
export type ControlNetModelField = S['ControlNetModelField'];
export type IPAdapterModelField = S['IPAdapterModelField'];
export type T2IAdapterModelField = S['T2IAdapterModelField'];
@ -72,8 +71,6 @@ export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
export type AnyModelConfig =
| LoRAModelConfig
| VAEModelConfig
@ -81,20 +78,9 @@ export type AnyModelConfig =
| IPAdapterModelConfig
| T2IAdapterModelConfig
| TextualInversionModelConfig
| RefinerMainModelConfig
| NonRefinerMainModelConfig
| MainModelConfig
| CLIPVisionDiffusersConfig;
type AnyModelConfig2 =
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
| S['LoRAConfig']
| S['TextualInversionConfig']
| S['IPAdapterConfig']
| S['CLIPVisionDiffusersConfig']
| S['T2IConfig'];
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora';
};
@ -119,16 +105,14 @@ export const isTextualInversionModelConfig = (config: AnyModelConfig): config is
return config.type === 'embedding';
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is NonRefinerMainModelConfig => {
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is RefinerMainModelConfig => {
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sdxl-refiner';
};
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus'];