fix(ui): fix remaining TS issues

This commit is contained in:
psychedelicious 2024-02-27 15:31:40 +11:00 committed by Kent Keirsey
parent ca00fabd79
commit cc229c3ea0
8 changed files with 56 additions and 93 deletions

View File

@ -148,11 +148,11 @@ export const AdvancedImport = () => {
<Flex gap={4}> <Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel> <FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" /> <BaseModelSelect control={control} name="base" />
</FormControl> </FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel> <FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" /> <ModelFormatSelect control={control} name="format" />
</FormControl> </FormControl>
</Flex> </Flex>
<Flex gap={4}> <Flex gap={4}>

View File

@ -14,7 +14,7 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
]; ];
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => { const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props); const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>( const onChange = useCallback<ComboboxOnChange>(

View File

@ -1,13 +1,12 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form'; import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => { const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props); const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' }); const type = useWatch({ control: props.control, name: 'type' });
@ -21,10 +20,10 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
const options: ComboboxOption[] = useMemo(() => { const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type; const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') { if (modelType === 'lora') {
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({ return [
value: format, { value: 'lycoris', label: 'LyCORIS' },
label: LORA_MODEL_FORMAT_MAP[format], { value: 'diffusers', label: 'Diffusers' },
})) as ComboboxOption[]; ];
} else if (modelType === 'embedding') { } else if (modelType === 'embedding') {
return [ return [
{ value: 'embedding_file', label: 'Embedding File' }, { value: 'embedding_file', label: 'Embedding File' },

View File

@ -15,7 +15,7 @@ const options: ComboboxOption[] = [
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string }, { value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string }, { value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string }, { value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
]; ] as const
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => { const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props); const { field } = useController(props);

View File

@ -14,22 +14,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice'; import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { useCallback, useMemo } from 'react'; import { useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models'; import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { import type { AnyModelConfig } from 'services/api/types';
AnyModelConfig,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
} from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect'; import BooleanSelect from './Fields/BooleanSelect';
@ -48,38 +39,38 @@ export const ModelEdit = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const modelData = useMemo(() => { // const modelData = useMemo(() => {
if (!data) { // if (!data) {
return null; // return null;
} // }
const modelFormat = data.format; // const modelFormat = data.format;
const modelType = data.type; // const modelType = data.type;
if (modelType === 'main') { // if (modelType === 'main') {
if (modelFormat === 'diffusers') { // if (modelFormat === 'diffusers') {
return data as DiffusersModelConfig; // return data as DiffusersModelConfig;
} else if (modelFormat === 'checkpoint') { // } else if (modelFormat === 'checkpoint') {
return data as CheckpointModelConfig; // return data as CheckpointModelConfig;
} // }
} // }
switch (modelType) { // switch (modelType) {
case 'lora': // case 'lora':
return data as LoRAModelConfig; // return data as LoRAModelConfig;
case 'embedding': // case 'embedding':
return data as TextualInversionModelConfig; // return data as TextualInversionModelConfig;
case 't2i_adapter': // case 't2i_adapter':
return data as T2IAdapterModelConfig; // return data as T2IAdapterModelConfig;
case 'ip_adapter': // case 'ip_adapter':
return data as IPAdapterModelConfig; // return data as IPAdapterModelConfig;
case 'controlnet': // case 'controlnet':
return data as ControlNetModelConfig; // return data as ControlNetModelConfig;
case 'vae': // case 'vae':
return data as VAEModelConfig; // return data as VAEModelConfig;
default: // default:
return null; // return null;
} // }
}, [data]); // }, [data]);
const { const {
register, register,
@ -88,9 +79,9 @@ export const ModelEdit = () => {
formState: { errors }, formState: { errors },
reset, reset,
watch, watch,
} = useForm<AnyModelConfig>({ } = useForm<UpdateModelArg['body']>({
defaultValues: { defaultValues: {
...modelData, ...data,
}, },
mode: 'onChange', mode: 'onChange',
}); });
@ -100,19 +91,19 @@ export const ModelEdit = () => {
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>( const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => { (values) => {
if (!modelData?.key) { if (!data?.key) {
return; return;
} }
const responseBody = { const responseBody: UpdateModelArg = {
key: modelData.key, key: data.key,
body: values, body: values,
}; };
updateModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
reset(payload as AnyModelConfig, { keepDefaultValues: true }); reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view')); dispatch(setSelectedModelMode('view'));
dispatch( dispatch(
addToast( addToast(
@ -135,7 +126,7 @@ export const ModelEdit = () => {
); );
}); });
}, },
[dispatch, modelData?.key, reset, t, updateModel] [dispatch, data?.key, reset, t, updateModel]
); );
const handleClickCancel = useCallback(() => { const handleClickCancel = useCallback(() => {
@ -146,7 +137,7 @@ export const ModelEdit = () => {
return <Text>{t('common.loading')}</Text>; return <Text>{t('common.loading')}</Text>;
} }
if (!modelData) { if (!data) {
return <Text>{t('common.somethingWentWrong')}</Text>; return <Text>{t('common.somethingWentWrong')}</Text>;
} }
return ( return (
@ -193,7 +184,7 @@ export const ModelEdit = () => {
<Flex gap={4}> <Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel> <FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect<AnyModelConfig> control={control} name="base" /> <BaseModelSelect control={control} name="base" />
</FormControl> </FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel> <FormLabel>{t('modelManager.modelType')}</FormLabel>
@ -203,7 +194,7 @@ export const ModelEdit = () => {
<Flex gap={4}> <Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel> <FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect<AnyModelConfig> control={control} name="format" /> <ModelFormatSelect control={control} name="format" />
</FormControl> </FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}> <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel> <FormLabel>{t('modelManager.path')}</FormLabel>

View File

@ -1,5 +1,4 @@
import type { ComboboxOption } from '@invoke-ai/ui-library'; import type { ComboboxOption } from '@invoke-ai/ui-library';
import type { LoRAModelFormat } from 'services/api/types';
/** /**
* Mapping of base model to human readable name * Mapping of base model to human readable name
@ -49,16 +48,6 @@ export const CLIP_SKIP_MAP = {
}, },
}; };
/**
* Mapping of LoRA format to human readable name
*/
export const LORA_MODEL_FORMAT_MAP: {
[key in LoRAModelFormat]: string;
} = {
lycoris: 'LyCORIS',
diffusers: 'Diffusers',
};
/** /**
* Mapping of schedulers to human readable name * Mapping of schedulers to human readable name
*/ */

View File

@ -19,7 +19,7 @@ import type {
import type { ApiTagDescription, tagTypes } from '..'; import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..'; import { api, buildV2Url, LIST_TAG } from '..';
type UpdateModelArg = { export type UpdateModelArg = {
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key']; key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
}; };

View File

@ -50,7 +50,6 @@ export type BaseModelType = S['BaseModelType'];
export type MainModelField = S['MainModelField']; export type MainModelField = S['MainModelField'];
export type VAEModelField = S['VAEModelField']; export type VAEModelField = S['VAEModelField'];
export type LoRAModelField = S['LoRAModelField']; export type LoRAModelField = S['LoRAModelField'];
export type LoRAModelFormat = S['LoRAModelFormat'];
export type ControlNetModelField = S['ControlNetModelField']; export type ControlNetModelField = S['ControlNetModelField'];
export type IPAdapterModelField = S['IPAdapterModelField']; export type IPAdapterModelField = S['IPAdapterModelField'];
export type T2IAdapterModelField = S['T2IAdapterModelField']; export type T2IAdapterModelField = S['T2IAdapterModelField'];
@ -72,8 +71,6 @@ export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig']; export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig; export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
export type AnyModelConfig = export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
| VAEModelConfig | VAEModelConfig
@ -81,20 +78,9 @@ export type AnyModelConfig =
| IPAdapterModelConfig | IPAdapterModelConfig
| T2IAdapterModelConfig | T2IAdapterModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| RefinerMainModelConfig | MainModelConfig
| NonRefinerMainModelConfig
| CLIPVisionDiffusersConfig; | CLIPVisionDiffusersConfig;
type AnyModelConfig2 =
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
| S['LoRAConfig']
| S['TextualInversionConfig']
| S['IPAdapterConfig']
| S['CLIPVisionDiffusersConfig']
| S['T2IConfig'];
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => { export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
return config.type === 'lora'; return config.type === 'lora';
}; };
@ -119,16 +105,14 @@ export const isTextualInversionModelConfig = (config: AnyModelConfig): config is
return config.type === 'embedding'; return config.type === 'embedding';
}; };
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is NonRefinerMainModelConfig => { export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner'; return config.type === 'main' && config.base !== 'sdxl-refiner';
}; };
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is RefinerMainModelConfig => { export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sdxl-refiner'; return config.type === 'main' && config.base === 'sdxl-refiner';
}; };
export type MergeModelConfig = S['Body_merge'];
export type ImportModelConfig = S['Body_import_model'];
export type ModelInstallJob = S['ModelInstallJob']; export type ModelInstallJob = S['ModelInstallJob'];
export type ModelInstallStatus = S['InstallStatus']; export type ModelInstallStatus = S['InstallStatus'];