feat: Add Auto Import Model

This commit is contained in:
blessedcoolant 2023-07-16 01:36:00 +12:00
parent dcbb3dc49a
commit e1c0ca1ab2
11 changed files with 167 additions and 294 deletions

View File

@ -454,7 +454,8 @@
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
"selectModel": "Select Model",
"importModels": "Import Models"
},
"parameters": {
"general": "General",

View File

@ -5,7 +5,7 @@ import AddModelsPanel from './subpanels/AddModelsPanel';
import MergeModelsPanel from './subpanels/MergeModelsPanel';
import ModelManagerPanel from './subpanels/ModelManagerPanel';
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
type ModelManagerTabInfo = {
id: ModelManagerTabName;
@ -20,8 +20,8 @@ const tabs: ModelManagerTabInfo[] = [
content: <ModelManagerPanel />,
},
{
id: 'addModels',
label: i18n.t('modelManager.addModel'),
id: 'importModels',
label: i18n.t('modelManager.importModels'),
content: <AddModelsPanel />,
},
{
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
<TabPanels sx={{ w: 'full', h: 'full' }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}

View File

@ -1,43 +1,39 @@
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { useState } from 'react';
import { useTranslation } from 'react-i18next';
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
import AddModels from './AddModelsPanel/AddModels';
import ScanModels from './AddModelsPanel/ScanModels';
type AddModelTabs = 'add' | 'scan';
export default function AddModelsPanel() {
const addNewModelUIOption = useAppSelector(
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
const { t } = useTranslation();
return (
<Flex flexDirection="column" gap={4}>
<Flex columnGap={4}>
<ButtonGroup isAttached>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
isChecked={addNewModelUIOption == 'ckpt'}
onClick={() => setAddModelTab('add')}
isChecked={addModelTab == 'add'}
size="sm"
>
{t('modelManager.addCheckpointModel')}
{t('modelManager.addModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
isChecked={addNewModelUIOption == 'diffusers'}
onClick={() => setAddModelTab('scan')}
isChecked={addModelTab == 'scan'}
size="sm"
>
{t('modelManager.addDiffuserModel')}
{t('modelManager.scanForModels')}
</IAIButton>
</Flex>
</ButtonGroup>
<Divider />
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
{addModelTab == 'add' && <AddModels />}
{addModelTab == 'scan' && <ScanModels />}
</Flex>
);
}

View File

@ -1,259 +0,0 @@
import {
Flex,
FormControl,
FormErrorMessage,
FormHelperText,
FormLabel,
Text,
VStack,
} from '@chakra-ui/react';
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
import { Field, Formik } from 'formik';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import IAIForm from 'common/components/IAIForm';
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
export default function AddDiffusersModel() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
function hasWhiteSpace(s: string) {
return /\s/.test(s);
}
function baseValidation(value: string) {
let error;
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
return error;
}
const addModelFormValues: InvokeDiffusersModelConfigProps = {
name: '',
description: '',
repo_id: '',
path: '',
format: 'diffusers',
default: false,
vae: {
repo_id: '',
path: '',
},
};
const addModelFormSubmitHandler = (
values: InvokeDiffusersModelConfigProps
) => {
const diffusersModelToAdd = values;
if (values.path === '') delete diffusersModelToAdd.path;
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
dispatch(addNewModel(diffusersModelToAdd));
dispatch(setAddNewModelUIOption(null));
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
<FormControl
isInvalid={!!errors.name && touched.name}
isRequired
>
<FormLabel htmlFor="name" fontSize="sm">
{t('modelManager.name')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="name"
name="name"
type="text"
validate={baseValidation}
isRequired
/>
{!!errors.name && touched.name ? (
<FormErrorMessage>{errors.name}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.nameValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* Description */}
<FormControl
isInvalid={!!errors.description && touched.description}
isRequired
>
<FormLabel htmlFor="description" fontSize="sm">
{t('modelManager.description')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="description"
name="description"
type="text"
isRequired
/>
{!!errors.description && touched.description ? (
<FormErrorMessage>{errors.description}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.descriptionValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
<Text fontWeight="bold" fontSize="sm">
{t('modelManager.formMessageDiffusersModelLocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersModelLocationDesc')}
</Text>
{/* Path */}
<FormControl isInvalid={!!errors.path && touched.path}>
<FormLabel htmlFor="path" fontSize="sm">
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.modelLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* Repo ID */}
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
<FormLabel htmlFor="repo_id" fontSize="sm">
{t('modelManager.repo_id')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="repo_id"
name="repo_id"
type="text"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.repoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIFormItemWrapper>
{/* VAE Path */}
<Text fontWeight="bold">
{t('modelManager.formMessageDiffusersVAELocation')}
</Text>
<Text
sx={{
fontSize: 'sm',
fontStyle: 'italic',
}}
variant="subtext"
>
{t('modelManager.formMessageDiffusersVAELocationDesc')}
</Text>
<FormControl
isInvalid={!!errors.vae?.path && touched.vae?.path}
>
<FormLabel htmlFor="vae.path" fontSize="sm">
{t('modelManager.vaeLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.path"
name="vae.path"
type="text"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeLocationValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
{/* VAE Repo ID */}
<FormControl
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
>
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
{t('modelManager.vaeRepoID')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="vae.repo_id"
name="vae.repo_id"
type="text"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
) : (
<FormHelperText margin={0}>
{t('modelManager.vaeRepoIDValidationMsg')}
</FormHelperText>
)}
</VStack>
</FormControl>
</IAIFormItemWrapper>
<IAIButton type="submit" isLoading={isProcessing}>
{t('modelManager.addModel')}
</IAIButton>
</VStack>
</IAIForm>
)}
</Formik>
</Flex>
);
}

View File

@ -0,0 +1,110 @@
import { Flex } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useTranslation } from 'react-i18next';
import { SelectItem } from '@mantine/core';
import { useForm } from '@mantine/form';
import { makeToast } from 'app/components/Toaster';
import { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { addToast } from 'features/system/store/systemSlice';
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
const predictionSelectData: SelectItem[] = [
{ label: 'None', value: 'none' },
{ label: 'v_prediction', value: 'v_prediction' },
{ label: 'epsilon', value: 'epsilon' },
{ label: 'sample', value: 'sample' },
];
type ExtendedImportModelConfig = {
location: string;
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
};
export default function AddModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
const addModelForm = useForm<ExtendedImportModelConfig>({
initialValues: {
location: '',
prediction_type: undefined,
},
});
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
const importModelResponseBody = {
location: values.location,
prediction_type:
values.prediction_type === 'none' ? undefined : values.prediction_type,
};
importMainModel({ body: importModelResponseBody })
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: 'Model Added',
status: 'success',
})
)
);
addModelForm.reset();
})
.catch((error) => {
if (error) {
console.log(error);
dispatch(
addToast(
makeToast({
title: `${error.data.detail} `,
status: 'error',
})
)
);
}
});
};
return (
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}>
<Flex
flexDirection="column"
overflow="scroll"
maxHeight={window.innerHeight - 270}
width="100%"
gap={4}
>
<IAIMantineTextInput
label="Model Location"
w="100%"
{...addModelForm.getInputProps('location')}
/>
<IAIMantineSelect
label="Prediction Type (used for Stable Diffusion 2.x Models)"
data={predictionSelectData}
defaultValue="none"
{...addModelForm.getInputProps('prediction_type')}
/>
<IAIButton
type="submit"
isLoading={isLoading}
isDisabled={isLoading || isProcessing}
>
{t('modelManager.addModel')}
</IAIButton>
</Flex>
</form>
);
}

View File

@ -22,7 +22,11 @@ export default function FoundModelsList() {
}
return (
<Flex sx={{ flexDirection: 'column' }}>
<Flex
sx={{
flexDirection: 'column',
}}
>
{foundModels.map((model) => (
<Flex key={model}>{model}</Flex>
))}

View File

@ -33,7 +33,7 @@ import SearchModels from './SearchModels';
const MIN_MODEL_SIZE = 64;
const MAX_MODEL_SIZE = 2048;
export default function AddCheckpointModel() {
export default function ScanModels() {
const dispatch = useAppDispatch();
const { t } = useTranslation();

View File

@ -11,7 +11,9 @@ export default function SearchModels() {
return (
<Flex flexDirection="column" w="100%">
<SearchFolderForm />
<FoundModelsList />
<Flex sx={{ maxHeight: window.innerHeight - 400, overflow: 'scroll' }}>
<FoundModelsList />
</Flex>
</Flex>
);
}

View File

@ -1,7 +1,5 @@
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
export type Coordinates = {
x: number;
y: number;
@ -22,7 +20,6 @@ export interface UIState {
shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean;
shouldUseSliders: boolean;
addNewModelUIOption: AddNewModelType;
shouldHidePreview: boolean;
shouldPinGallery: boolean;
shouldShowGallery: boolean;

View File

@ -7,6 +7,7 @@ import {
ControlNetModelConfig,
ConvertModelConfig,
DiffusersModelConfig,
ImportModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
@ -78,6 +79,13 @@ type MergeMainModelArg = {
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
type ImportMainModelArg = {
body: ImportModelConfig;
};
type ImportMainModelResponse =
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
type SearchFolderResponse =
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
@ -168,6 +176,19 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
importMainModels: build.mutation<
ImportMainModelResponse,
ImportMainModelArg
>({
query: ({ body }) => {
return {
url: `models/import`,
method: 'POST',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
@ -356,6 +377,7 @@ export const {
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useImportMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
useGetModelsInFolderQuery,

View File

@ -60,7 +60,7 @@ export type AnyModelConfig =
export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
export type SearchFolderConfig = components['schemas'];
export type ImportModelConfig = components['schemas']['Body_import_model'];
// Graphs
export type Graph = components['schemas']['Graph'];