get old UI working somewhat with new endpoints

This commit is contained in:
Mary Hipp 2024-02-20 10:03:10 -05:00 committed by psychedelicious
parent 09295ae43b
commit bdc2b8069b
10 changed files with 120 additions and 155 deletions

View File

@ -1,15 +1,18 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import AdvancedAddModels from './AdvancedAddModels';
import SimpleAddModels from './SimpleAddModels';
import { useGetModelImportsQuery } from '../../../../services/api/endpoints/models';
const AddModels = () => {
const { t } = useTranslation();
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
const { data } = useGetModelImportsQuery({});
console.log({ data });
return (
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
<ButtonGroup>
@ -24,6 +27,7 @@ const AddModels = () => {
{addModelMode === 'simple' && <SimpleAddModels />}
{addModelMode === 'advanced' && <AdvancedAddModels />}
</Flex>
<Flex>{data?.map((model) => <Text>{model.status}</Text>)}</Flex>
</Flex>
);
};

View File

@ -36,11 +36,10 @@ const SimpleAddModels = () => {
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
location: values.location,
prediction_type: values.prediction_type === 'none' ? undefined : values.prediction_type,
config: values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ body: importModelResponseBody })
importMainModel({ source: values.location, config: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(

View File

@ -2,13 +2,13 @@ import { Flex, Text } from '@invoke-ai/ui-library';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
import { DiffusersModelConfig, LoRAConfig, MainModelConfig } from '../../../services/api/types';
const ModelManagerPanel = () => {
const [selectedModelId, setSelectedModelId] = useState<string>();
@ -41,16 +41,16 @@ const ModelEdit = (props: ModelEditProps) => {
const { t } = useTranslation();
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />;
if (model?.format === 'checkpoint') {
return <CheckpointModelEdit key={model.key} model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model as DiffusersModelConfig} />;
if (model?.format === 'diffusers') {
return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
}
if (model?.model_type === 'lora') {
return <LoRAModelEdit key={model.id} model={model} />;
if (model?.type === 'lora') {
return <LoRAModelEdit key={model.key} model={model} />;
}
return (

View File

@ -21,11 +21,9 @@ import { memo, useCallback, useEffect, useState } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { CheckpointModelConfig } from 'services/api/endpoints/models';
import { useGetCheckpointConfigsQuery, useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import ModelConvert from './ModelConvert';
import { CheckpointModelConfig } from '../../../../services/api/types';
type CheckpointModelEditProps = {
model: CheckpointModelConfig;
@ -34,7 +32,7 @@ type CheckpointModelEditProps = {
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
@ -56,12 +54,12 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
reset,
} = useForm<CheckpointModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
name: model.name ? model.name : '',
base: model.base,
type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'checkpoint',
format: 'checkpoint',
vae: model.vae ? model.vae : '',
config: model.config ? model.config : '',
variant: model.variant,
@ -74,11 +72,10 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
key: model.key,
body: values,
};
updateMainModel(responseBody)
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
@ -103,7 +100,7 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
);
});
},
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
[dispatch, model.key, reset, t, updateModel]
);
return (
@ -111,13 +108,13 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')}
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text>
</Flex>
{![''].includes(model.base_model) ? (
{![''].includes(model.base) ? (
<ModelConvert model={model} />
) : (
<Badge p={2} borderRadius={4} bg="error.400">
@ -130,20 +127,20 @@ const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('model_name', {
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" />
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -9,9 +9,8 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { DiffusersModelConfig } from 'services/api/endpoints/models';
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
import type { DiffusersModelConfig } from 'services/api/types';
import { useUpdateModelsMutation } from '../../../../services/api/endpoints/models';
type DiffusersModelEditProps = {
model: DiffusersModelConfig;
@ -20,7 +19,7 @@ type DiffusersModelEditProps = {
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -33,12 +32,12 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
reset,
} = useForm<DiffusersModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
name: model.name ? model.name : '',
base: model.base,
type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'diffusers',
format: 'diffusers',
vae: model.vae ? model.vae : '',
variant: model.variant,
},
@ -48,12 +47,11 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
key: model.key,
body: values,
};
updateMainModel(responseBody)
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
@ -78,37 +76,37 @@ const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
);
});
},
[dispatch, model.base_model, model.model_name, reset, t, updateMainModel]
[dispatch, model.key, reset, t, updateModel]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')}
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
</Text>
</Flex>
<Divider />
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('model_name', {
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" />
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -8,7 +8,6 @@ import { memo, useCallback } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useUpdateLoRAModelsMutation } from 'services/api/endpoints/models';
import type { LoRAModelConfig } from 'services/api/types';
type LoRAModelEditProps = {
@ -18,7 +17,7 @@ type LoRAModelEditProps = {
const LoRAModelEdit = (props: LoRAModelEditProps) => {
const { model } = props;
const [updateLoRAModel, { isLoading }] = useUpdateLoRAModelsMutation();
const [updateModel, { isLoading }] = useUpdateModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -31,12 +30,12 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
reset,
} = useForm<LoRAModelConfig>({
defaultValues: {
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'lora',
name: model.name ? model.name : '',
base: model.base,
type: 'lora',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: model.model_format,
format: model.format,
},
mode: 'onChange',
});
@ -44,12 +43,11 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
(values) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
key: model.key,
body: values,
};
updateLoRAModel(responseBody)
updateModel(responseBody)
.unwrap()
.then((payload) => {
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
@ -74,17 +72,17 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
);
});
},
[dispatch, model.base_model, model.model_name, reset, t, updateLoRAModel]
[dispatch, model.key, reset, t, updateModel]
);
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{model.model_name}
{model.name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[model.base_model]} {t('modelManager.model')} {LORA_MODEL_FORMAT_MAP[model.model_format]}{' '}
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
{t('common.format')}
</Text>
</Flex>
@ -92,20 +90,20 @@ const LoRAModelEdit = (props: LoRAModelEditProps) => {
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<FormControl isInvalid={Boolean(errors.model_name)}>
<FormControl isInvalid={Boolean(errors.name)}>
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('model_name', {
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
})}
/>
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<FormControl>
<FormLabel>{t('modelManager.description')}</FormLabel>
<Input {...register('description')} />
</FormControl>
<BaseModelSelect<LoRAModelConfig> control={control} name="base_model" />
<BaseModelSelect<LoRAModelConfig> control={control} name="base" />
<FormControl isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>

View File

@ -54,8 +54,8 @@ const ModelConvert = (props: ModelConvertProps) => {
const modelConvertHandler = useCallback(() => {
const queryArg = {
base_model: model.base_model,
model_name: model.model_name,
base_model: model.base,
model_name: model.name,
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
};
@ -74,7 +74,7 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${model.model_name}`,
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
status: 'info',
})
)
@ -86,7 +86,7 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
title: `${t('modelManager.modelConverted')}: ${model.name}`,
status: 'success',
})
)
@ -96,13 +96,13 @@ const ModelConvert = (props: ModelConvertProps) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${model.model_name}`,
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
status: 'error',
})
)
);
});
}, [convertModel, customSaveLocation, dispatch, model.base_model, model.model_name, saveLocation, t]);
}, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
return (
<>
@ -116,7 +116,7 @@ const ModelConvert = (props: ModelConvertProps) => {
🧨 {t('modelManager.convertToDiffusers')}
</Button>
<ConfirmationAlertDialog
title={`${t('modelManager.convert')} ${model.model_name}`}
title={`${t('modelManager.convert')} ${model.name}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`}

View File

@ -5,10 +5,11 @@ import type { ChangeEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { ALL_BASE_MODELS } from 'services/api/constants';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
type ModelListProps = {
selectedModelId: string | undefined;
@ -177,9 +178,9 @@ const ModelListWrapper = memo((props: ModelListWrapperProps) => {
</Text>
{modelList.map((model) => (
<ModelListItem
key={model.id}
key={model.key}
model={model}
isSelected={selected.selectedModelId === model.id}
isSelected={selected.selectedModelId === model.key}
setSelectedModelId={selected.setSelectedModelId}
/>
))}

View File

@ -15,8 +15,8 @@ import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
import { useDeleteLoRAModelsMutation, useDeleteMainModelsMutation } from 'services/api/endpoints/models';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import { LoRAConfig, MainModelConfig } from '../../../../services/api/types';
type ModelListItemProps = {
model: MainModelConfig | LoRAConfig;
@ -27,29 +27,23 @@ type ModelListItemProps = {
const ModelListItem = (props: ModelListItemProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const [deleteMainModel] = useDeleteMainModelsMutation();
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
const [deleteModel] = useDeleteModelsMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const { model, isSelected, setSelectedModelId } = props;
const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
setSelectedModelId(model.key);
}, [model.key, setSelectedModelId]);
const handleModelDelete = useCallback(() => {
const method = {
main: deleteMainModel,
lora: deleteLoRAModel,
}[model.model_type];
method(model)
deleteModel({ key: model.key })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleted')}: ${model.model_name}`,
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
status: 'success',
})
)
@ -60,7 +54,7 @@ const ModelListItem = (props: ModelListItemProps) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelDeleteFailed')}: ${model.model_name}`,
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
status: 'error',
})
)
@ -68,7 +62,7 @@ const ModelListItem = (props: ModelListItemProps) => {
}
});
setSelectedModelId(undefined);
}, [deleteMainModel, deleteLoRAModel, model, setSelectedModelId, dispatch, t]);
}, [deleteModel, model, setSelectedModelId, dispatch, t]);
return (
<Flex gap={2} alignItems="center" w="full">
@ -85,10 +79,10 @@ const ModelListItem = (props: ModelListItemProps) => {
>
<Flex gap={4} alignItems="center">
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
{MODEL_TYPE_SHORT_MAP[model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP]}
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
</Badge>
<Tooltip label={model.description} placement="bottom">
<Text>{model.model_name}</Text>
<Text>{model.name}</Text>
</Tooltip>
</Flex>
</Flex>

View File

@ -7,12 +7,10 @@ import type {
AnyModelConfig,
BaseModelType,
ControlNetModelConfig,
ImportModelConfig,
IPAdapterModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
ModelType,
T2IAdapterModelConfig,
TextualInversionModelConfig,
VAEModelConfig,
@ -21,37 +19,21 @@ import type {
import type { ApiTagDescription, tagTypes } from '..';
import { api, buildV2Url, LIST_TAG } from '..';
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
type UpdateModelArg = {
key: NonNullable<operations['update_model_record']['parameters']['path']['key']>;
body: NonNullable<operations['update_model_record']['requestBody']['content']['application/json']>;
};
type UpdateLoRAModelArg = {
base_model: BaseModelType;
model_name: string;
body: LoRAModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type UpdateLoRAModelResponse = UpdateMainModelResponse;
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
model_type: ModelType;
key: string;
};
type DeleteMainModelResponse = void;
type DeleteLoRAModelArg = DeleteMainModelArg;
type DeleteLoRAModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
@ -59,36 +41,40 @@ type ConvertMainModelArg = {
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
body: ImportModelConfig;
source: NonNullable<operations['heuristic_import_model']['parameters']['query']['source']>;
access_token?: operations['heuristic_import_model']['parameters']['query']['access_token'];
config: NonNullable<operations['heuristic_import_model']['requestBody']['content']['application/json']>;
};
type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type AddMainModelArg = {
body: MainModelConfig;
};
type AddMainModelResponse = paths['/api/v1/models/add']['post']['responses']['201']['content']['application/json'];
type AddMainModelResponse = paths['/api/v2/models/add']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v1/models/sync']['post']['responses']['201']['content']['application/json'];
type SyncModelsResponse = paths['/api/v2/models/sync']['post']['responses']['201']['content']['application/json'];
export type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
paths['/api/v2/models/search']['get']['responses']['200']['content']['application/json'];
type CheckpointConfigsResponse =
paths['/api/v1/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
type SearchFolderArg = operations['search_for_models']['parameters']['query'];
@ -179,10 +165,10 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}),
updateMainModels: build.mutation<UpdateMainModelResponse, UpdateMainModelArg>({
query: ({ base_model, model_name, body }) => {
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`${base_model}/main/${model_name}`),
url: buildModelsUrl(`i/${key}`),
method: 'PATCH',
body: body,
};
@ -190,11 +176,12 @@ export const modelsApi = api.injectEndpoints({
invalidatesTags: ['Model'],
}),
importMainModels: build.mutation<ImportMainModelResponse, ImportMainModelArg>({
query: ({ body }) => {
query: ({ source, config, access_token }) => {
return {
url: buildModelsUrl('import'),
url: buildModelsUrl('heuristic_import'),
params: { source, access_token },
method: 'POST',
body: body,
body: config,
};
},
invalidatesTags: ['Model'],
@ -209,10 +196,10 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
deleteMainModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ base_model, model_name, model_type }) => {
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
query: ({ key }) => {
return {
url: buildModelsUrl(`${base_model}/${model_type}/${model_name}`),
url: buildModelsUrl(`i/${key}`),
method: 'DELETE',
};
},
@ -264,25 +251,6 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
transformResponse: buildTransformResponse<LoRAModelConfig>(loraModelsAdapter),
}),
updateLoRAModels: build.mutation<UpdateLoRAModelResponse, UpdateLoRAModelArg>({
query: ({ base_model, model_name, body }) => {
return {
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
deleteLoRAModels: build.mutation<DeleteLoRAModelResponse, DeleteLoRAModelArg>({
query: ({ base_model, model_name }) => {
return {
url: buildModelsUrl(`${base_model}/lora/${model_name}`),
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
}),
getControlNetModels: build.query<EntityState<ControlNetModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'controlnet' } }),
providesTags: buildProvidesTags<ControlNetModelConfig>('ControlNetModel'),
@ -316,6 +284,13 @@ export const modelsApi = api.injectEndpoints({
};
},
}),
getModelImports: build.query<ListImportModelsResponse, void>({
query: (arg) => {
return {
url: buildModelsUrl(`import`),
};
},
}),
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
query: () => {
return {
@ -335,15 +310,14 @@ export const {
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useDeleteModelsMutation,
useUpdateModelsMutation,
useImportMainModelsMutation,
useAddMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useDeleteLoRAModelsMutation,
useUpdateLoRAModelsMutation,
useSyncModelsMutation,
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
} = modelsApi;