mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
get old UI working somewhat with new endpoints
This commit is contained in:
parent
09295ae43b
commit
bdc2b8069b
@ -1,15 +1,18 @@
|
||||
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import AdvancedAddModels from './AdvancedAddModels';
|
||||
import SimpleAddModels from './SimpleAddModels';
|
||||
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
|
||||
|
||||
const AddModels = () => {
|
||||
const { t } = useTranslation();
|
||||
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
|
||||
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
|
||||
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
|
||||
const { data } = useGetModelImportsQuery({});
|
||||
console.log({ data });
|
||||
return (
|
||||
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
|
||||
<ButtonGroup>
|
||||
@ -24,6 +27,7 @@ const AddModels = () => {
|
||||
{addModelMode === 'simple' && <SimpleAddModels />}
|
||||
{addModelMode === 'advanced' && <AdvancedAddModels />}
|
||||
</Flex>
|
||||
<Flex>{data?.map((model) => <Text>{model.status}</Text>)}</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -36,11 +36,10 @@ const SimpleAddModels = () => {
|
||||
|
||||
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
|
||||
const importModelResponseBody = {
|
||||
location: values.location,
|
||||
prediction_type: values.prediction_type === 'none' ? undefined : values.prediction_type,
|
||||
config: values.prediction_type === 'none' ? undefined : values.prediction_type,
|
||||
};
|
||||
|
||||
importMainModel({ body: importModelResponseBody })
|
||||
importMainModel({ source: values.location, config: importModelResponseBody })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
|
@ -2,13 +2,13 @@ import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
|
||||
|
||||
const ModelManagerPanel = () => {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
@ -41,16 +41,16 @@ const ModelEdit = (props: ModelEditProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { model } = props;
|
||||
|
||||
if (model?.model_format === 'checkpoint') {
|
||||
return <CheckpointModelEdit key={model.id} model={model} />;
|
||||
if (model?.format === 'checkpoint') {
|
||||
return <CheckpointModelEdit key={model.key} model={model} />;
|
||||
}
|
||||
|
||||
if (model?.model_format === 'diffusers') {
|
||||
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />;
|
||||
if (model?.format === 'diffusers') {
|
||||
return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
|
||||
}
|
||||
|
||||
if (model?.model_type === 'lora') {
|
||||
return <LoRAModelEdit key={model.id} model={model} />;
|
||||
if (model?.type === 'lora') {
|
||||
return <LoRAModelEdit key={model.key} model={model} />;
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -21,11 +21,9 @@ import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { CheckpointModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import ModelConvert from './ModelConvert';
|
||||
import { CheckpointModelConfig } from '../../../../services/api/types';
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
model: CheckpointModelConfig;
|
||||
@ -34,7 +32,7 @@ type CheckpointModelEditProps = {
|
||||
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
||||
|
||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||
@ -56,12 +54,12 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
reset,
|
||||
} = useForm<CheckpointModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'main',
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'main',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: 'checkpoint',
|
||||
format: 'checkpoint',
|
||||
vae: model.vae ? model.vae : '',
|
||||
config: model.config ? model.config : '',
|
||||
variant: model.variant,
|
||||
@ -74,11 +72,10 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody)
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
|
||||
@ -103,7 +100,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
@ -111,13 +108,13 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.model_name}
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')}
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
||||
</Text>
|
||||
</Flex>
|
||||
{![''].includes(model.base_model) ? (
|
||||
{![''].includes(model.base) ? (
|
||||
<ModelConvert model={model} />
|
||||
) : (
|
||||
<Badge p={2} borderRadius={4} bg="error.400">
|
||||
@ -130,20 +127,20 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('model_name', {
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" />
|
||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
|
@ -9,9 +9,8 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
model: DiffusersModelConfig;
|
||||
@ -20,7 +19,7 @@ type DiffusersModelEditProps = {
|
||||
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@ -33,12 +32,12 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
reset,
|
||||
} = useForm<DiffusersModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'main',
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'main',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: 'diffusers',
|
||||
format: 'diffusers',
|
||||
vae: model.vae ? model.vae : '',
|
||||
variant: model.variant,
|
||||
},
|
||||
@ -48,12 +47,11 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateMainModel(responseBody)
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
|
||||
@ -78,37 +76,37 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.model_name}
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')}
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('model_name', {
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" />
|
||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
|
@ -8,7 +8,6 @@ import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
@ -18,7 +17,7 @@ type LoRAModelEditProps = {
|
||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@ -31,12 +30,12 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
reset,
|
||||
} = useForm<LoRAModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'lora',
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'lora',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: model.model_format,
|
||||
format: model.format,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
@ -44,12 +43,11 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateLoRAModel(responseBody)
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
||||
@ -74,17 +72,17 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.base_model, model.model_name, reset, t, updateLoRAModel]
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.model_name}
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.model_format]}{' '}
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
|
||||
{t('common.format')}
|
||||
</Text>
|
||||
</Flex>
|
||||
@ -92,20 +90,20 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('model_name', {
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
|
||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base" />
|
||||
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
|
@ -54,8 +54,8 @@ const ModelConvert = (props: ModelConvertProps) => {
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
const queryArg = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
base_model: model.base,
|
||||
model_name: model.name,
|
||||
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
|
||||
};
|
||||
|
||||
@ -74,7 +74,7 @@ const ModelConvert = (props: ModelConvertProps) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${model.model_name}`,
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
|
||||
status: 'info',
|
||||
})
|
||||
)
|
||||
@ -86,7 +86,7 @@ const ModelConvert = (props: ModelConvertProps) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
|
||||
title: `${t('modelManager.modelConverted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
@ -96,13 +96,13 @@ const ModelConvert = (props: ModelConvertProps) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${model.model_name}`,
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
}, [convertModel, customSaveLocation, dispatch, model.base_model, model.model_name, saveLocation, t]);
|
||||
}, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -116,7 +116,7 @@ const ModelConvert = (props: ModelConvertProps) => {
|
||||
🧨 {t('modelManager.convertToDiffusers')}
|
||||
</Button>
|
||||
<ConfirmationAlertDialog
|
||||
title={`${t('modelManager.convert')} ${model.model_name}`}
|
||||
title={`${t('modelManager.convert')} ${model.name}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
cancelCallback={modelConvertCancelHandler}
|
||||
acceptButtonText={`${t('modelManager.convert')}`}
|
||||
|
@ -5,10 +5,11 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
@ -177,9 +178,9 @@ const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
</Text>
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
key={model.key}
|
||||
model={model}
|
||||
isSelected={selected.selectedModelId === model.id}
|
||||
isSelected={selected.selectedModelId === model.key}
|
||||
setSelectedModelId={selected.setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
|
@ -15,8 +15,8 @@ import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: MainModelConfig | LoRAConfig;
|
||||
@ -27,29 +27,23 @@ type ModelListItemProps = {
|
||||
const ModelListItem = (props: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const [deleteMainModel] = useDeleteMainModelsMutation();
|
||||
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const { model, isSelected, setSelectedModelId } = props;
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
setSelectedModelId(model.id);
|
||||
}, [model.id, setSelectedModelId]);
|
||||
setSelectedModelId(model.key);
|
||||
}, [model.key, setSelectedModelId]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
const method = {
|
||||
main: deleteMainModel,
|
||||
lora: deleteLoRAModel,
|
||||
}[model.model_type];
|
||||
|
||||
method(model)
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleted')}: ${model.model_name}`,
|
||||
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
@ -60,7 +54,7 @@ const ModelListItem = (props: ModelListItemProps) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.model_name}`,
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
@ -68,7 +62,7 @@ const ModelListItem = (props: ModelListItemProps) => {
|
||||
}
|
||||
});
|
||||
setSelectedModelId(undefined);
|
||||
}, [deleteMainModel, deleteLoRAModel, model, setSelectedModelId, dispatch, t]);
|
||||
}, [deleteModel, model, setSelectedModelId, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
@ -85,10 +79,10 @@ const ModelListItem = (props: ModelListItemProps) => {
|
||||
>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
||||
{MODEL_TYPE_SHORT_MAP[model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
||||
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
||||
</Badge>
|
||||
<Tooltip label={model.description} placement="bottom">
|
||||
<Text>{model.model_name}</Text>
|
||||
<Text>{model.name}</Text>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
@ -7,12 +7,10 @@ import type {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
ImportModelConfig,
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
ModelType,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
@ -21,37 +19,21 @@ import type {
|
||||
import type { ApiTagDescription, tagTypes } from '..';
|
||||
import { api, buildV2Url, LIST_TAG } from '..';
|
||||
|
||||
type UpdateMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: MainModelConfig;
|
||||
type UpdateModelArg = {
|
||||
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
|
||||
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
|
||||
};
|
||||
|
||||
type UpdateLoRAModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: LoRAModelConfig;
|
||||
};
|
||||
|
||||
type UpdateMainModelResponse =
|
||||
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
type UpdateLoRAModelResponse = UpdateMainModelResponse;
|
||||
|
||||
type DeleteMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
model_type: ModelType;
|
||||
key: string;
|
||||
};
|
||||
|
||||
type DeleteMainModelResponse = void;
|
||||
|
||||
type DeleteLoRAModelArg = DeleteMainModelArg;
|
||||
|
||||
type DeleteLoRAModelResponse = void;
|
||||
|
||||
type ConvertMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
@ -59,36 +41,40 @@ type ConvertMainModelArg = {
|
||||
};
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type MergeMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
body: MergeModelConfig;
|
||||
};
|
||||
|
||||
type MergeMainModelResponse =
|
||||
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
|
||||
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportMainModelArg = {
|
||||
body: ImportModelConfig;
|
||||
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
|
||||
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
|
||||
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
|
||||
};
|
||||
|
||||
type ImportMainModelResponse =
|
||||
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type ListImportModelsResponse =
|
||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type AddMainModelArg = {
|
||||
body: MainModelConfig;
|
||||
};
|
||||
|
||||
type AddMainModelResponse = paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
|
||||
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type SyncModelsResponse = paths['/api/v1/models/sync']['post']['responses']['201']['content']['application/json'];
|
||||
type SyncModelsResponse = paths['/api/v2/models/sync']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
export type SearchFolderResponse =
|
||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type CheckpointConfigsResponse =
|
||||
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
||||
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
|
||||
|
||||
@ -179,10 +165,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||
}),
|
||||
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`${base_model}/main/${model_name}`),
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
@ -190,11 +176,12 @@ export const modelsApi = api.injectEndpoints({
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
|
||||
query: ({ body }) => {
|
||||
query: ({ source, config, access_token }) => {
|
||||
return {
|
||||
url: buildModelsUrl('import'),
|
||||
url: buildModelsUrl('heuristic_import'),
|
||||
params: { source, access_token },
|
||||
method: 'POST',
|
||||
body: body,
|
||||
body: config,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
@ -209,10 +196,10 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
deleteMainModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
query: ({ base_model, model_name, model_type }) => {
|
||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
query: ({ key }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`),
|
||||
url: buildModelsUrl(`i/${key}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
@ -264,25 +251,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
|
||||
}),
|
||||
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
deleteLoRAModels: build.mutation<DeleteLoRAModelResponse, DeleteLoRAModelArg>({
|
||||
query: ({ base_model, model_name }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
|
||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
|
||||
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
|
||||
@ -316,6 +284,13 @@ export const modelsApi = api.injectEndpoints({
|
||||
};
|
||||
},
|
||||
}),
|
||||
getModelImports: build.query<ListImportModelsResponse, void>({
|
||||
query: (arg) => {
|
||||
return {
|
||||
url: buildModelsUrl(`import`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
@ -335,15 +310,14 @@ export const {
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useUpdateMainModelsMutation,
|
||||
useDeleteMainModelsMutation,
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useImportMainModelsMutation,
|
||||
useAddMainModelsMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useDeleteLoRAModelsMutation,
|
||||
useUpdateLoRAModelsMutation,
|
||||
useSyncModelsMutation,
|
||||
useGetModelsInFolderQuery,
|
||||
useGetCheckpointConfigsQuery,
|
||||
useGetModelImportsQuery,
|
||||
} = modelsApi;
|
||||
|
Loading…
Reference in New Issue
Block a user