mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Restore Update Model functionality
This commit is contained in:
parent
af239fa122
commit
afb46564e8
@ -342,6 +342,7 @@
|
||||
"safetensorModels": "SafeTensors",
|
||||
"modelAdded": "Model Added",
|
||||
"modelUpdated": "Model Updated",
|
||||
"modelUpdateFailed": "Model Update Failed",
|
||||
"modelEntryDeleted": "Model Entry Deleted",
|
||||
"cannotUseSpaces": "Cannot Use Spaces",
|
||||
"addNew": "Add New",
|
||||
|
@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { S } from 'services/api/types';
|
||||
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { components } from 'services/api/schema';
|
||||
import ModelConvert from './ModelConvert';
|
||||
|
||||
const baseModelSelectData = [
|
||||
@ -25,13 +29,13 @@ const variantSelectData = [
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
export type CheckpointModel =
|
||||
| S<'StableDiffusion1ModelCheckpointConfig'>
|
||||
| S<'StableDiffusion2ModelCheckpointConfig'>;
|
||||
export type CheckpointModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: CheckpointModel;
|
||||
retrievedModel: CheckpointModelConfig;
|
||||
};
|
||||
|
||||
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
|
||||
const { modelToEdit, retrievedModel } = props;
|
||||
|
||||
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const checkpointEditForm = useForm({
|
||||
const checkpointEditForm = useForm<CheckpointModelConfig>({
|
||||
initialValues: {
|
||||
name: retrievedModel.name,
|
||||
name: retrievedModel.name ? retrievedModel.name : '',
|
||||
base_model: retrievedModel.base_model,
|
||||
type: 'main',
|
||||
path: retrievedModel.path,
|
||||
description: retrievedModel.description,
|
||||
path: retrievedModel.path ? retrievedModel.path : '',
|
||||
description: retrievedModel.description ? retrievedModel.description : '',
|
||||
model_format: 'checkpoint',
|
||||
vae: retrievedModel.vae,
|
||||
config: retrievedModel.config,
|
||||
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||
config: retrievedModel.config ? retrievedModel.config : '',
|
||||
variant: retrievedModel.variant,
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values) => {
|
||||
console.log(values);
|
||||
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: retrievedModel.name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody);
|
||||
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return modelToEdit ? (
|
||||
@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||
)}
|
||||
>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<IAIInput
|
||||
label={t('modelManager.name')}
|
||||
{...checkpointEditForm.getInputProps('name')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.description')}
|
||||
{...checkpointEditForm.getInputProps('description')}
|
||||
|
@ -6,20 +6,23 @@ import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useForm } from '@mantine/form';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||
import { S } from 'services/api/types';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import { components } from 'services/api/schema';
|
||||
|
||||
type DiffusersModel =
|
||||
| S<'StableDiffusion1ModelDiffusersConfig'>
|
||||
| S<'StableDiffusion2ModelDiffusersConfig'>;
|
||||
export type DiffusersModelConfig =
|
||||
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
|
||||
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
modelToEdit: string;
|
||||
retrievedModel: DiffusersModel;
|
||||
retrievedModel: DiffusersModelConfig;
|
||||
};
|
||||
|
||||
const baseModelSelectData = [
|
||||
@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
);
|
||||
const { retrievedModel, modelToEdit } = props;
|
||||
|
||||
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const diffusersEditForm = useForm({
|
||||
const diffusersEditForm = useForm<DiffusersModelConfig>({
|
||||
initialValues: {
|
||||
name: retrievedModel.name,
|
||||
name: retrievedModel.name ? retrievedModel.name : '',
|
||||
base_model: retrievedModel.base_model,
|
||||
type: 'main',
|
||||
path: retrievedModel.path,
|
||||
description: retrievedModel.description,
|
||||
path: retrievedModel.path ? retrievedModel.path : '',
|
||||
description: retrievedModel.description ? retrievedModel.description : '',
|
||||
model_format: 'diffusers',
|
||||
vae: retrievedModel.vae,
|
||||
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||
variant: retrievedModel.variant,
|
||||
},
|
||||
});
|
||||
|
||||
const editModelFormSubmitHandler = (values) => {
|
||||
console.log(values);
|
||||
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
|
||||
const responseBody = {
|
||||
base_model: retrievedModel.base_model,
|
||||
model_name: retrievedModel.name,
|
||||
body: values,
|
||||
};
|
||||
updateMainModel(responseBody);
|
||||
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
return modelToEdit ? (
|
||||
@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
||||
)}
|
||||
>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<IAIInput
|
||||
label={t('modelManager.name')}
|
||||
{...diffusersEditForm.getInputProps('name')}
|
||||
/>
|
||||
<IAIInput
|
||||
label={t('modelManager.description')}
|
||||
{...diffusersEditForm.getInputProps('description')}
|
||||
|
@ -2,6 +2,7 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { cloneDeep } from 'lodash-es';
|
||||
import {
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
@ -32,6 +33,12 @@ type AnyModelConfigEntity =
|
||||
| TextualInversionModelConfigEntity
|
||||
| VaeModelConfigEntity;
|
||||
|
||||
type UpdateMainModelQuery = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: MainModelConfig;
|
||||
};
|
||||
|
||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
});
|
||||
@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
updateMainModels: build.mutation<
|
||||
EntityState<MainModelConfigEntity>,
|
||||
UpdateMainModelQuery
|
||||
>({
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
return {
|
||||
url: `models/${base_model}/main/${model_name}`,
|
||||
method: 'PATCH',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['MainModel'],
|
||||
}),
|
||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||
providesTags: (result, error, arg) => {
|
||||
@ -244,4 +264,5 @@ export const {
|
||||
useGetLoRAModelsQuery,
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useUpdateMainModelsMutation,
|
||||
} = modelsApi;
|
||||
|
Loading…
Reference in New Issue
Block a user