mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(ui): fix remaining TS issues
This commit is contained in:
parent
97ecd99b9c
commit
aaeef03593
@ -148,11 +148,11 @@ export const AdvancedImport = () => {
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
|
@ -14,7 +14,7 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
|
@ -1,13 +1,12 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { LORA_MODEL_FORMAT_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field, formState } = useController(props);
|
||||
const type = useWatch({ control: props.control, name: 'type' });
|
||||
|
||||
@ -21,10 +20,10 @@ const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
const modelType = type || formState.defaultValues?.type;
|
||||
if (modelType === 'lora') {
|
||||
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
||||
value: format,
|
||||
label: LORA_MODEL_FORMAT_MAP[format],
|
||||
})) as ComboboxOption[];
|
||||
return [
|
||||
{ value: 'lycoris', label: 'LyCORIS' },
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
];
|
||||
} else if (modelType === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
|
@ -15,7 +15,7 @@ const options: ComboboxOption[] = [
|
||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||
];
|
||||
] as const
|
||||
|
||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
|
@ -14,22 +14,13 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setSelectedModelMode } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type {
|
||||
AnyModelConfig,
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
DiffusersModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
} from 'services/api/types';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import BooleanSelect from './Fields/BooleanSelect';
|
||||
@ -48,38 +39,38 @@ export const ModelEdit = () => {
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
const modelFormat = data.format;
|
||||
const modelType = data.type;
|
||||
// const modelData = useMemo(() => {
|
||||
// if (!data) {
|
||||
// return null;
|
||||
// }
|
||||
// const modelFormat = data.format;
|
||||
// const modelType = data.type;
|
||||
|
||||
if (modelType === 'main') {
|
||||
if (modelFormat === 'diffusers') {
|
||||
return data as DiffusersModelConfig;
|
||||
} else if (modelFormat === 'checkpoint') {
|
||||
return data as CheckpointModelConfig;
|
||||
}
|
||||
}
|
||||
// if (modelType === 'main') {
|
||||
// if (modelFormat === 'diffusers') {
|
||||
// return data as DiffusersModelConfig;
|
||||
// } else if (modelFormat === 'checkpoint') {
|
||||
// return data as CheckpointModelConfig;
|
||||
// }
|
||||
// }
|
||||
|
||||
switch (modelType) {
|
||||
case 'lora':
|
||||
return data as LoRAModelConfig;
|
||||
case 'embedding':
|
||||
return data as TextualInversionModelConfig;
|
||||
case 't2i_adapter':
|
||||
return data as T2IAdapterModelConfig;
|
||||
case 'ip_adapter':
|
||||
return data as IPAdapterModelConfig;
|
||||
case 'controlnet':
|
||||
return data as ControlNetModelConfig;
|
||||
case 'vae':
|
||||
return data as VAEModelConfig;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}, [data]);
|
||||
// switch (modelType) {
|
||||
// case 'lora':
|
||||
// return data as LoRAModelConfig;
|
||||
// case 'embedding':
|
||||
// return data as TextualInversionModelConfig;
|
||||
// case 't2i_adapter':
|
||||
// return data as T2IAdapterModelConfig;
|
||||
// case 'ip_adapter':
|
||||
// return data as IPAdapterModelConfig;
|
||||
// case 'controlnet':
|
||||
// return data as ControlNetModelConfig;
|
||||
// case 'vae':
|
||||
// return data as VAEModelConfig;
|
||||
// default:
|
||||
// return null;
|
||||
// }
|
||||
// }, [data]);
|
||||
|
||||
const {
|
||||
register,
|
||||
@ -88,9 +79,9 @@ export const ModelEdit = () => {
|
||||
formState: { errors },
|
||||
reset,
|
||||
watch,
|
||||
} = useForm<AnyModelConfig>({
|
||||
} = useForm<UpdateModelArg['body']>({
|
||||
defaultValues: {
|
||||
...modelData,
|
||||
...data,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
@ -100,19 +91,19 @@ export const ModelEdit = () => {
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
||||
(values) => {
|
||||
if (!modelData?.key) {
|
||||
if (!data?.key) {
|
||||
return;
|
||||
}
|
||||
|
||||
const responseBody = {
|
||||
key: modelData.key,
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: data.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as AnyModelConfig, { keepDefaultValues: true });
|
||||
reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
dispatch(
|
||||
addToast(
|
||||
@ -135,7 +126,7 @@ export const ModelEdit = () => {
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, modelData?.key, reset, t, updateModel]
|
||||
[dispatch, data?.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
@ -146,7 +137,7 @@ export const ModelEdit = () => {
|
||||
return <Text>{t('common.loading')}</Text>;
|
||||
}
|
||||
|
||||
if (!modelData) {
|
||||
if (!data) {
|
||||
return <Text>{t('common.somethingWentWrong')}</Text>;
|
||||
}
|
||||
return (
|
||||
@ -193,7 +184,7 @@ export const ModelEdit = () => {
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||
<BaseModelSelect control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
@ -203,7 +194,7 @@ export const ModelEdit = () => {
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('common.format')}</FormLabel>
|
||||
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||
<ModelFormatSelect control={control} name="format" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
||||
|
@ -1,5 +1,4 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { LoRAModelFormat } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Mapping of base model to human readable name
|
||||
@ -49,16 +48,6 @@ export const CLIP_SKIP_MAP = {
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of LoRA format to human readable name
|
||||
*/
|
||||
export const LORA_MODEL_FORMAT_MAP: {
|
||||
[key in LoRAModelFormat]: string;
|
||||
} = {
|
||||
lycoris: 'LyCORIS',
|
||||
diffusers: 'Diffusers',
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of schedulers to human readable name
|
||||
*/
|
||||
|
@ -19,7 +19,7 @@ import type {
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
type UpdateModelArg = {
|
||||
export type UpdateModelArg = {
|
||||
key: paths['/api/v2/models/i/{key}']['patch']['parameters']['path']['key'];
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
@ -50,7 +50,6 @@ export type BaseModelType = S['BaseModelType'];
|
||||
export type MainModelField = S['MainModelField'];
|
||||
export type VAEModelField = S['VAEModelField'];
|
||||
export type LoRAModelField = S['LoRAModelField'];
|
||||
export type LoRAModelFormat = S['LoRAModelFormat'];
|
||||
export type ControlNetModelField = S['ControlNetModelField'];
|
||||
export type IPAdapterModelField = S['IPAdapterModelField'];
|
||||
export type T2IAdapterModelField = S['T2IAdapterModelField'];
|
||||
@ -72,8 +71,6 @@ export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
|
||||
export type RefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'sdxl-refiner' };
|
||||
export type NonRefinerMainModelConfig = Omit<MainModelConfig, 'base'> & { base: 'any' | 'sd-1' | 'sd-2' | 'sdxl' };
|
||||
export type AnyModelConfig =
|
||||
| LoRAModelConfig
|
||||
| VAEModelConfig
|
||||
@ -81,20 +78,9 @@ export type AnyModelConfig =
|
||||
| IPAdapterModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| RefinerMainModelConfig
|
||||
| NonRefinerMainModelConfig
|
||||
| MainModelConfig
|
||||
| CLIPVisionDiffusersConfig;
|
||||
|
||||
type AnyModelConfig2 =
|
||||
| (S['MainDiffusersConfig'] | S['MainCheckpointConfig'])
|
||||
| (S['VaeDiffusersConfig'] | S['VaeCheckpointConfig'])
|
||||
| (S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'])
|
||||
| S['LoRAConfig']
|
||||
| S['TextualInversionConfig']
|
||||
| S['IPAdapterConfig']
|
||||
| S['CLIPVisionDiffusersConfig']
|
||||
| S['T2IConfig'];
|
||||
|
||||
export const isLoRAModelConfig = (config: AnyModelConfig): config is LoRAModelConfig => {
|
||||
return config.type === 'lora';
|
||||
};
|
||||
@ -119,16 +105,14 @@ export const isTextualInversionModelConfig = (config: AnyModelConfig): config is
|
||||
return config.type === 'embedding';
|
||||
};
|
||||
|
||||
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is NonRefinerMainModelConfig => {
|
||||
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base !== 'sdxl-refiner';
|
||||
};
|
||||
|
||||
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is RefinerMainModelConfig => {
|
||||
export const isRefinerMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'sdxl-refiner';
|
||||
};
|
||||
|
||||
export type MergeModelConfig = S['Body_merge'];
|
||||
export type ImportModelConfig = S['Body_import_model'];
|
||||
export type ModelInstallJob = S['ModelInstallJob'];
|
||||
export type ModelInstallStatus = S['InstallStatus'];
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user