mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Restore Update Model functionality
This commit is contained in:
parent
af239fa122
commit
afb46564e8
@ -342,6 +342,7 @@
|
|||||||
"safetensorModels": "SafeTensors",
|
"safetensorModels": "SafeTensors",
|
||||||
"modelAdded": "Model Added",
|
"modelAdded": "Model Added",
|
||||||
"modelUpdated": "Model Updated",
|
"modelUpdated": "Model Updated",
|
||||||
|
"modelUpdateFailed": "Model Update Failed",
|
||||||
"modelEntryDeleted": "Model Entry Deleted",
|
"modelEntryDeleted": "Model Entry Deleted",
|
||||||
"cannotUseSpaces": "Cannot Use Spaces",
|
"cannotUseSpaces": "Cannot Use Spaces",
|
||||||
"addNew": "Add New",
|
"addNew": "Add New",
|
||||||
|
@ -11,7 +11,11 @@ import IAIButton from 'common/components/IAIButton';
|
|||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
import { S } from 'services/api/types';
|
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import { components } from 'services/api/schema';
|
||||||
import ModelConvert from './ModelConvert';
|
import ModelConvert from './ModelConvert';
|
||||||
|
|
||||||
const baseModelSelectData = [
|
const baseModelSelectData = [
|
||||||
@ -25,13 +29,13 @@ const variantSelectData = [
|
|||||||
{ value: 'depth', label: 'Depth' },
|
{ value: 'depth', label: 'Depth' },
|
||||||
];
|
];
|
||||||
|
|
||||||
export type CheckpointModel =
|
export type CheckpointModelConfig =
|
||||||
| S<'StableDiffusion1ModelCheckpointConfig'>
|
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
|
||||||
| S<'StableDiffusion2ModelCheckpointConfig'>;
|
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
|
||||||
|
|
||||||
type CheckpointModelEditProps = {
|
type CheckpointModelEditProps = {
|
||||||
modelToEdit: string;
|
modelToEdit: string;
|
||||||
retrievedModel: CheckpointModel;
|
retrievedModel: CheckpointModelConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
||||||
@ -41,25 +45,52 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
|
|
||||||
const { modelToEdit, retrievedModel } = props;
|
const { modelToEdit, retrievedModel } = props;
|
||||||
|
|
||||||
|
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const checkpointEditForm = useForm({
|
const checkpointEditForm = useForm<CheckpointModelConfig>({
|
||||||
initialValues: {
|
initialValues: {
|
||||||
name: retrievedModel.name,
|
name: retrievedModel.name ? retrievedModel.name : '',
|
||||||
base_model: retrievedModel.base_model,
|
base_model: retrievedModel.base_model,
|
||||||
type: 'main',
|
type: 'main',
|
||||||
path: retrievedModel.path,
|
path: retrievedModel.path ? retrievedModel.path : '',
|
||||||
description: retrievedModel.description,
|
description: retrievedModel.description ? retrievedModel.description : '',
|
||||||
model_format: 'checkpoint',
|
model_format: 'checkpoint',
|
||||||
vae: retrievedModel.vae,
|
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||||
config: retrievedModel.config,
|
config: retrievedModel.config ? retrievedModel.config : '',
|
||||||
variant: retrievedModel.variant,
|
variant: retrievedModel.variant,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values) => {
|
const editModelFormSubmitHandler = (values: CheckpointModelConfig) => {
|
||||||
console.log(values);
|
const responseBody = {
|
||||||
|
base_model: retrievedModel.base_model,
|
||||||
|
model_name: retrievedModel.name,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
updateMainModel(responseBody);
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return modelToEdit ? (
|
return modelToEdit ? (
|
||||||
@ -88,10 +119,6 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
<IAIInput
|
|
||||||
label={t('modelManager.name')}
|
|
||||||
{...checkpointEditForm.getInputProps('name')}
|
|
||||||
/>
|
|
||||||
<IAIInput
|
<IAIInput
|
||||||
label={t('modelManager.description')}
|
label={t('modelManager.description')}
|
||||||
{...checkpointEditForm.getInputProps('description')}
|
{...checkpointEditForm.getInputProps('description')}
|
||||||
|
@ -6,20 +6,23 @@ import { Divider, Flex, Text } from '@chakra-ui/react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { useForm } from '@mantine/form';
|
import { useForm } from '@mantine/form';
|
||||||
|
import { makeToast } from 'app/components/Toaster';
|
||||||
import type { RootState } from 'app/store/store';
|
import type { RootState } from 'app/store/store';
|
||||||
import IAIButton from 'common/components/IAIButton';
|
import IAIButton from 'common/components/IAIButton';
|
||||||
import IAIInput from 'common/components/IAIInput';
|
import IAIInput from 'common/components/IAIInput';
|
||||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
import { MODEL_TYPE_MAP } from 'features/system/components/ModelSelect';
|
||||||
import { S } from 'services/api/types';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { useUpdateMainModelsMutation } from 'services/api/endpoints/models';
|
||||||
|
import { components } from 'services/api/schema';
|
||||||
|
|
||||||
type DiffusersModel =
|
export type DiffusersModelConfig =
|
||||||
| S<'StableDiffusion1ModelDiffusersConfig'>
|
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
|
||||||
| S<'StableDiffusion2ModelDiffusersConfig'>;
|
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
|
||||||
|
|
||||||
type DiffusersModelEditProps = {
|
type DiffusersModelEditProps = {
|
||||||
modelToEdit: string;
|
modelToEdit: string;
|
||||||
retrievedModel: DiffusersModel;
|
retrievedModel: DiffusersModelConfig;
|
||||||
};
|
};
|
||||||
|
|
||||||
const baseModelSelectData = [
|
const baseModelSelectData = [
|
||||||
@ -39,24 +42,51 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
|||||||
);
|
);
|
||||||
const { retrievedModel, modelToEdit } = props;
|
const { retrievedModel, modelToEdit } = props;
|
||||||
|
|
||||||
|
const [updateMainModel, { error }] = useUpdateMainModelsMutation();
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const diffusersEditForm = useForm({
|
const diffusersEditForm = useForm<DiffusersModelConfig>({
|
||||||
initialValues: {
|
initialValues: {
|
||||||
name: retrievedModel.name,
|
name: retrievedModel.name ? retrievedModel.name : '',
|
||||||
base_model: retrievedModel.base_model,
|
base_model: retrievedModel.base_model,
|
||||||
type: 'main',
|
type: 'main',
|
||||||
path: retrievedModel.path,
|
path: retrievedModel.path ? retrievedModel.path : '',
|
||||||
description: retrievedModel.description,
|
description: retrievedModel.description ? retrievedModel.description : '',
|
||||||
model_format: 'diffusers',
|
model_format: 'diffusers',
|
||||||
vae: retrievedModel.vae,
|
vae: retrievedModel.vae ? retrievedModel.vae : '',
|
||||||
variant: retrievedModel.variant,
|
variant: retrievedModel.variant,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const editModelFormSubmitHandler = (values) => {
|
const editModelFormSubmitHandler = (values: DiffusersModelConfig) => {
|
||||||
console.log(values);
|
const responseBody = {
|
||||||
|
base_model: retrievedModel.base_model,
|
||||||
|
model_name: retrievedModel.name,
|
||||||
|
body: values,
|
||||||
|
};
|
||||||
|
updateMainModel(responseBody);
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
return modelToEdit ? (
|
return modelToEdit ? (
|
||||||
@ -77,10 +107,6 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
|
|||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||||
<IAIInput
|
|
||||||
label={t('modelManager.name')}
|
|
||||||
{...diffusersEditForm.getInputProps('name')}
|
|
||||||
/>
|
|
||||||
<IAIInput
|
<IAIInput
|
||||||
label={t('modelManager.description')}
|
label={t('modelManager.description')}
|
||||||
{...diffusersEditForm.getInputProps('description')}
|
{...diffusersEditForm.getInputProps('description')}
|
||||||
|
@ -2,6 +2,7 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
|
|||||||
import { cloneDeep } from 'lodash-es';
|
import { cloneDeep } from 'lodash-es';
|
||||||
import {
|
import {
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
ControlNetModelConfig,
|
ControlNetModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
@ -32,6 +33,12 @@ type AnyModelConfigEntity =
|
|||||||
| TextualInversionModelConfigEntity
|
| TextualInversionModelConfigEntity
|
||||||
| VaeModelConfigEntity;
|
| VaeModelConfigEntity;
|
||||||
|
|
||||||
|
type UpdateMainModelQuery = {
|
||||||
|
base_model: BaseModelType;
|
||||||
|
model_name: string;
|
||||||
|
body: MainModelConfig;
|
||||||
|
};
|
||||||
|
|
||||||
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
});
|
});
|
||||||
@ -101,6 +108,19 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
updateMainModels: build.mutation<
|
||||||
|
EntityState<MainModelConfigEntity>,
|
||||||
|
UpdateMainModelQuery
|
||||||
|
>({
|
||||||
|
query: ({ base_model, model_name, body }) => {
|
||||||
|
return {
|
||||||
|
url: `models/${base_model}/main/${model_name}`,
|
||||||
|
method: 'PATCH',
|
||||||
|
body: body,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
invalidatesTags: ['MainModel'],
|
||||||
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
|
||||||
providesTags: (result, error, arg) => {
|
providesTags: (result, error, arg) => {
|
||||||
@ -244,4 +264,5 @@ export const {
|
|||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
|
useUpdateMainModelsMutation,
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
Loading…
Reference in New Issue
Block a user