get old UI working somewhat with new endpoints

This commit is contained in:
Mary Hipp 2024-02-20 10:03:10 -05:00 committed by psychedelicious
parent 09295ae43b
commit bdc2b8069b
10 changed files with 120 additions and 155 deletions

View File

@ -1,15 +1,18 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library'; import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
import { memo, useCallback, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import AdvancedAddModels from './AdvancedAddModels'; import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels'; import SimpleAddModels from './SimpleAddModels';
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
const AddModels = () => { const AddModels = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple'); const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []); const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []); const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
const { data } = useGetModelImportsQuery({});
console.log({ data });
return ( return (
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}> <Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
<ButtonGroup> <ButtonGroup>
@ -24,6 +27,7 @@ const AddModels = () => {
{addModelMode === 'simple' && <SimpleAddModels />} {addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />} {addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex> </Flex>
<Flex>{data?.map((model) => <Text>{model.status}</Text>)}</Flex>
</Flex> </Flex>
); );
}; };

View File

@ -36,11 +36,10 @@ const SimpleAddModels = () => {
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => { const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = { const importModelResponseBody = {
location: values.location, config: values.prediction_type === 'none' ? undefined : values.prediction_type,
prediction_type: values.prediction_type === 'none' ? undefined : values.prediction_type,
}; };
importMainModel({ body: importModelResponseBody }) importMainModel({ source: values.location, config: importModelResponseBody })
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(

View File

@ -2,13 +2,13 @@ import { Flex, Text } from '@invoke-ai/ui-library';
import { memo, useState } from 'react'; import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit'; import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
const ModelManagerPanel = () => { const ModelManagerPanel = () => {
const [selectedModelId, setSelectedModelId] = useState<string>(); const [selectedModelId, setSelectedModelId] = useState<string>();
@ -41,16 +41,16 @@ const ModelEdit = (props: ModelEditProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const { model } = props; const { model } = props;
if (model?.model_format === 'checkpoint') { if (model?.format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />; return <CheckpointModelEdit key={model.key} model={model} />;
} }
if (model?.model_format === 'diffusers') { if (model?.format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />; return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
} }
if (model?.model_type === 'lora') { if (model?.type === 'lora') {
return <LoRAModelEdit key={model.id} model={model} />; return <LoRAModelEdit key={model.key} model={model} />;
} }
return ( return (

View File

@ -21,11 +21,9 @@ import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfig } from 'services/api/endpoints/models'; import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
import { CheckpointModelConfig } from '../../../../services/api/types';
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
model: CheckpointModelConfig; model: CheckpointModelConfig;
@ -34,7 +32,7 @@ type CheckpointModelEditProps = {
const CheckpointModelEdit = (props: CheckpointModelEditProps) => { const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const { model } = props; const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const [updateModel, { isLoading }] = useUpdateModelsMutation();
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery(); const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false); const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
@ -56,12 +54,12 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
reset, reset,
} = useForm<CheckpointModelConfig>({ } = useForm<CheckpointModelConfig>({
defaultValues: { defaultValues: {
model_name: model.model_name ? model.model_name : '', name: model.name ? model.name : '',
base_model: model.base_model, base: model.base,
model_type: 'main', type: 'main',
path: model.path ? model.path : '', path: model.path ? model.path : '',
description: model.description ? model.description : '', description: model.description ? model.description : '',
model_format: 'checkpoint', format: 'checkpoint',
vae: model.vae ? model.vae : '', vae: model.vae ? model.vae : '',
config: model.config ? model.config : '', config: model.config ? model.config : '',
variant: model.variant, variant: model.variant,
@ -74,11 +72,10 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>( const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => { (values) => {
const responseBody = { const responseBody = {
base_model: model.base_model, key: model.key,
model_name: model.model_name,
body: values, body: values,
}; };
updateMainModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
reset(payload as CheckpointModelConfig, { keepDefaultValues: true }); reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
@ -103,7 +100,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
); );
}); });
}, },
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel] [dispatch, model.key, reset, t, updateModel]
); );
return ( return (
@ -111,13 +108,13 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{model.model_name} {model.name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text> </Text>
</Flex> </Flex>
{![''].includes(model.base_model) ? ( {![''].includes(model.base) ? (
<ModelConvert model={model} /> <ModelConvert model={model} />
) : ( ) : (
<Badge p={2} borderRadius={4} bg="error.400"> <Badge p={2} borderRadius={4} bg="error.400">
@ -130,20 +127,20 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll"> <Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
<form onSubmit={handleSubmit(onSubmit)}> <form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}> <FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel> <FormLabel>{t('modelManager.name')}</FormLabel>
<Input <Input
{...register('model_name', { {...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters', validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})} })}
/> />
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>} {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl> </FormControl>
<FormControl> <FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} /> <Input {...register('description')} />
</FormControl> </FormControl>
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" /> <BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" /> <ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}> <FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -9,9 +9,8 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types'; import type { DiffusersModelConfig } from 'services/api/types';
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
model: DiffusersModelConfig; model: DiffusersModelConfig;
@ -20,7 +19,7 @@ type DiffusersModelEditProps = {
const DiffusersModelEdit = (props: DiffusersModelEditProps) => { const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const { model } = props; const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation(); const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -33,12 +32,12 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
reset, reset,
} = useForm<DiffusersModelConfig>({ } = useForm<DiffusersModelConfig>({
defaultValues: { defaultValues: {
model_name: model.model_name ? model.model_name : '', name: model.name ? model.name : '',
base_model: model.base_model, base: model.base,
model_type: 'main', type: 'main',
path: model.path ? model.path : '', path: model.path ? model.path : '',
description: model.description ? model.description : '', description: model.description ? model.description : '',
model_format: 'diffusers', format: 'diffusers',
vae: model.vae ? model.vae : '', vae: model.vae ? model.vae : '',
variant: model.variant, variant: model.variant,
}, },
@ -48,12 +47,11 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>( const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => { (values) => {
const responseBody = { const responseBody = {
base_model: model.base_model, key: model.key,
model_name: model.model_name,
body: values, body: values,
}; };
updateMainModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
reset(payload as DiffusersModelConfig, { keepDefaultValues: true }); reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
@ -78,37 +76,37 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
); );
}); });
}, },
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel] [dispatch, model.key, reset, t, updateModel]
); );
return ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{model.model_name} {model.name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text> </Text>
</Flex> </Flex>
<Divider /> <Divider />
<form onSubmit={handleSubmit(onSubmit)}> <form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}> <FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel> <FormLabel>{t('modelManager.name')}</FormLabel>
<Input <Input
{...register('model_name', { {...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters', validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})} })}
/> />
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>} {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl> </FormControl>
<FormControl> <FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} /> <Input {...register('description')} />
</FormControl> </FormControl>
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" /> <BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" /> <ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}> <FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -8,7 +8,6 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form'; import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types'; import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = { type LoRAModelEditProps = {
@ -18,7 +17,7 @@ type LoRAModelEditProps = {
const LoRAModelEdit = (props: LoRAModelEditProps) => { const LoRAModelEdit = (props: LoRAModelEditProps) => {
const { model } = props; const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation(); const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -31,12 +30,12 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
reset, reset,
} = useForm<LoRAModelConfig>({ } = useForm<LoRAModelConfig>({
defaultValues: { defaultValues: {
model_name: model.model_name ? model.model_name : '', name: model.name ? model.name : '',
base_model: model.base_model, base: model.base,
model_type: 'lora', type: 'lora',
path: model.path ? model.path : '', path: model.path ? model.path : '',
description: model.description ? model.description : '', description: model.description ? model.description : '',
model_format: model.model_format, format: model.format,
}, },
mode: 'onChange', mode: 'onChange',
}); });
@ -44,12 +43,11 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>( const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => { (values) => {
const responseBody = { const responseBody = {
base_model: model.base_model, key: model.key,
model_name: model.model_name,
body: values, body: values,
}; };
updateLoRAModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
.then((payload) => { .then((payload) => {
reset(payload as LoRAModelConfig, { keepDefaultValues: true }); reset(payload as LoRAModelConfig, { keepDefaultValues: true });
@ -74,17 +72,17 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
); );
}); });
}, },
[dispatch, model.base_model, model.model_name, reset, t, updateLoRAModel] [dispatch, model.key, reset, t, updateModel]
); );
return ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{model.model_name} {model.name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} {LORA_MODEL_FORMAT_MAP[model.model_format]}{' '} {MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
{t('common.format')} {t('common.format')}
</Text> </Text>
</Flex> </Flex>
@ -92,20 +90,20 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<form onSubmit={handleSubmit(onSubmit)}> <form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}> <FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel> <FormLabel>{t('modelManager.name')}</FormLabel>
<Input <Input
{...register('model_name', { {...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters', validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})} })}
/> />
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>} {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl> </FormControl>
<FormControl> <FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} /> <Input {...register('description')} />
</FormControl> </FormControl>
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" /> <BaseModelSelect<LoRAModelConfig> control={control} name="base" />
<FormControl isInvalid={Boolean(errors.path)}> <FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel> <FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -54,8 +54,8 @@ const ModelConvert = (props: ModelConvertProps) => {
const modelConvertHandler = useCallback(() => { const modelConvertHandler = useCallback(() => {
const queryArg = { const queryArg = {
base_model: model.base_model, base_model: model.base,
model_name: model.model_name, model_name: model.name,
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined, convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
}; };
@ -74,7 +74,7 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${model.model_name}`, title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
status: 'info', status: 'info',
}) })
) )
@ -86,7 +86,7 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`, title: `${t('modelManager.modelConverted')}: ${model.name}`,
status: 'success', status: 'success',
}) })
) )
@ -96,13 +96,13 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${model.model_name}`, title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
status: 'error', status: 'error',
}) })
) )
); );
}); });
}, [convertModel, customSaveLocation, dispatch, model.base_model, model.model_name, saveLocation, t]); }, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
return ( return (
<> <>
@ -116,7 +116,7 @@ const ModelConvert = (props: ModelConvertProps) => {
🧨 {t('modelManager.convertToDiffusers')} 🧨 {t('modelManager.convertToDiffusers')}
</Button> </Button>
<ConfirmationAlertDialog <ConfirmationAlertDialog
title={`${t('modelManager.convert')} ${model.model_name}`} title={`${t('modelManager.convert')} ${model.name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler} cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}

View File

@ -5,10 +5,11 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react'; import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants'; import { ALL_BASE_MODELS } from 'services/api/constants';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; // import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
type ModelListProps = { type ModelListProps = {
selectedModelId: string | undefined; selectedModelId: string | undefined;
@ -177,9 +178,9 @@ const ModelListWrapper = memo((props: ModelListWrapperProps) => {
</Text> </Text>
{modelList.map((model) => ( {modelList.map((model) => (
<ModelListItem <ModelListItem
key={model.id} key={model.key}
model={model} model={model}
isSelected={selected.selectedModelId === model.id} isSelected={selected.selectedModelId === model.key}
setSelectedModelId={selected.setSelectedModelId} setSelectedModelId={selected.setSelectedModelId}
/> />
))} ))}

View File

@ -15,8 +15,8 @@ import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi'; import { PiTrashSimpleBold } from 'react-icons/pi';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models'; import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models'; import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
type ModelListItemProps = { type ModelListItemProps = {
model: MainModelConfig | LoRAConfig; model: MainModelConfig | LoRAConfig;
@ -27,29 +27,23 @@ type ModelListItemProps = {
const ModelListItem = (props: ModelListItemProps) => { const ModelListItem = (props: ModelListItemProps) => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation(); const [deleteModel] = useDeleteModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const { model, isSelected, setSelectedModelId } = props; const { model, isSelected, setSelectedModelId } = props;
const handleSelectModel = useCallback(() => { const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id); setSelectedModelId(model.key);
}, [model.id, setSelectedModelId]); }, [model.key, setSelectedModelId]);
const handleModelDelete = useCallback(() => { const handleModelDelete = useCallback(() => {
const method = { deleteModel({ key: model.key })
main: deleteMainModel,
lora: deleteLoRAModel,
}[model.model_type];
method(model)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.model_name}`, title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success', status: 'success',
}) })
) )
@ -60,7 +54,7 @@ const ModelListItem = (props: ModelListItemProps) => {
dispatch( dispatch(
addToast( addToast(
makeToast({ makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${model.model_name}`, title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error', status: 'error',
}) })
) )
@ -68,7 +62,7 @@ const ModelListItem = (props: ModelListItemProps) => {
} }
}); });
setSelectedModelId(undefined); setSelectedModelId(undefined);
}, [deleteMainModel, deleteLoRAModel, model, setSelectedModelId, dispatch, t]); }, [deleteModel, model, setSelectedModelId, dispatch, t]);
return ( return (
<Flex gap={2} alignItems="center" w="full"> <Flex gap={2} alignItems="center" w="full">
@ -85,10 +79,10 @@ const ModelListItem = (props: ModelListItemProps) => {
> >
<Flex gap={4} alignItems="center"> <Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid"> <Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{MODEL_TYPE_SHORT_MAP[model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP]} {MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
</Badge> </Badge>
<Tooltip label={model.description} placement="bottom"> <Tooltip label={model.description} placement="bottom">
<Text>{model.model_name}</Text> <Text>{model.name}</Text>
</Tooltip> </Tooltip>
</Flex> </Flex>
</Flex> </Flex>

View File

@ -7,12 +7,10 @@ import type {
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ControlNetModelConfig, ControlNetModelConfig,
ImportModelConfig,
IPAdapterModelConfig, IPAdapterModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig, MergeModelConfig,
ModelType,
T2IAdapterModelConfig, T2IAdapterModelConfig,
TextualInversionModelConfig, TextualInversionModelConfig,
VAEModelConfig, VAEModelConfig,
@ -21,37 +19,21 @@ import type {
import type { ApiTagDescription, tagTypes } from '..'; import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..'; import { api, buildV2Url, LIST_TAG } from '..';
type UpdateMainModelArg = { type UpdateModelArg = {
base_model: BaseModelType; key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
model_name: string; body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
body: MainModelConfig;
}; };
type UpdateLoRAModelArg = { type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = { type DeleteMainModelArg = {
base_model: BaseModelType; key: string;
model_name: string;
model_type: ModelType;
}; };
type DeleteMainModelResponse = void; type DeleteMainModelResponse = void;
type DeleteLoRAModelArg = DeleteMainModelArg;
type DeleteLoRAModelResponse = void;
type ConvertMainModelArg = { type ConvertMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
@ -59,36 +41,40 @@ type ConvertMainModelArg = {
}; };
type ConvertMainModelResponse = type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json']; paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = { type MergeMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
body: MergeModelConfig; body: MergeModelConfig;
}; };
type MergeMainModelResponse = type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = { type ImportMainModelArg = {
body: ImportModelConfig; source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
}; };
type ImportMainModelResponse = type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json']; paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type AddMainModelArg = { type AddMainModelArg = {
body: MainModelConfig; body: MainModelConfig;
}; };
type AddMainModelResponse = paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json']; type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v1/models/sync']['post']['responses']['201']['content']['application/json']; type SyncModelsResponse = paths['/api/v2/models/sync']['post']['responses']['201']['content']['application/json'];
export type SearchFolderResponse = export type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse = type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query']; type SearchFolderArg = operations['search_for_models']['parameters']['query'];
@ -179,10 +165,10 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<MainModelConfig>('MainModel'), providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter), transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}), }),
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({ updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ base_model, model_name, body }) => { query: ({ key, body }) => {
return { return {
url: buildModelsUrl(`${base_model}/main/${model_name}`), url: buildModelsUrl(`i/${key}`),
method: 'PATCH', method: 'PATCH',
body: body, body: body,
}; };
@ -190,11 +176,12 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({ importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
query: ({ body }) => { query: ({ source, config, access_token }) => {
return { return {
url: buildModelsUrl('import'), url: buildModelsUrl('heuristic_import'),
params: { source, access_token },
method: 'POST', method: 'POST',
body: body, body: config,
}; };
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
@ -209,10 +196,10 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
deleteMainModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({ deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ base_model, model_name, model_type }) => { query: ({ key }) => {
return { return {
url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`), url: buildModelsUrl(`i/${key}`),
method: 'DELETE', method: 'DELETE',
}; };
}, },
@ -264,25 +251,6 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'), providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter), transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
}), }),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
return {
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<DeleteLoRAModelResponse, DeleteLoRAModelArg>({
query: ({ base_model, model_name }) => {
return {
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({ getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'), providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
@ -316,6 +284,13 @@ export const modelsApi = api.injectEndpoints({
}; };
}, },
}), }),
getModelImports: build.query<ListImportModelsResponse, void>({
query: (arg) => {
return {
url: buildModelsUrl(`import`),
};
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({ getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => { query: () => {
return { return {
@ -335,15 +310,14 @@ export const {
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useUpdateMainModelsMutation, useDeleteModelsMutation,
useDeleteMainModelsMutation, useUpdateModelsMutation,
useImportMainModelsMutation, useImportMainModelsMutation,
useAddMainModelsMutation, useAddMainModelsMutation,
useConvertMainModelsMutation, useConvertMainModelsMutation,
useMergeMainModelsMutation, useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation, useSyncModelsMutation,
useGetModelsInFolderQuery, useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery, useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
} = modelsApi; } = modelsApi;