mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Auto Import Model
This commit is contained in:
parent
dcbb3dc49a
commit
e1c0ca1ab2
@ -454,7 +454,8 @@
|
||||
"none": "none",
|
||||
"addDifference": "Add Difference",
|
||||
"pickModelType": "Pick Model Type",
|
||||
"selectModel": "Select Model"
|
||||
"selectModel": "Select Model",
|
||||
"importModels": "Import Models"
|
||||
},
|
||||
"parameters": {
|
||||
"general": "General",
|
||||
|
@ -5,7 +5,7 @@ import AddModelsPanel from './subpanels/AddModelsPanel';
|
||||
import MergeModelsPanel from './subpanels/MergeModelsPanel';
|
||||
import ModelManagerPanel from './subpanels/ModelManagerPanel';
|
||||
|
||||
type ModelManagerTabName = 'modelManager' | 'addModels' | 'mergeModels';
|
||||
type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels';
|
||||
|
||||
type ModelManagerTabInfo = {
|
||||
id: ModelManagerTabName;
|
||||
@ -20,8 +20,8 @@ const tabs: ModelManagerTabInfo[] = [
|
||||
content: <ModelManagerPanel />,
|
||||
},
|
||||
{
|
||||
id: 'addModels',
|
||||
label: i18n.t('modelManager.addModel'),
|
||||
id: 'importModels',
|
||||
label: i18n.t('modelManager.importModels'),
|
||||
content: <AddModelsPanel />,
|
||||
},
|
||||
{
|
||||
@ -46,7 +46,7 @@ const ModelManagerTab = () => {
|
||||
</Tab>
|
||||
))}
|
||||
</TabList>
|
||||
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
|
||||
<TabPanels sx={{ w: 'full', h: 'full' }}>
|
||||
{tabs.map((tab) => (
|
||||
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
|
||||
{tab.content}
|
||||
|
@ -1,43 +1,39 @@
|
||||
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import { useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import AddCheckpointModel from './AddModelsPanel/AddCheckpointModel';
|
||||
import AddDiffusersModel from './AddModelsPanel/AddDiffusersModel';
|
||||
import AddModels from './AddModelsPanel/AddModels';
|
||||
import ScanModels from './AddModelsPanel/ScanModels';
|
||||
|
||||
type AddModelTabs = 'add' | 'scan';
|
||||
|
||||
export default function AddModelsPanel() {
|
||||
const addNewModelUIOption = useAppSelector(
|
||||
(state: RootState) => state.ui.addNewModelUIOption
|
||||
);
|
||||
|
||||
const { colorMode } = useColorMode();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4}>
|
||||
<Flex columnGap={4}>
|
||||
<ButtonGroup isAttached>
|
||||
<IAIButton
|
||||
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
|
||||
isChecked={addNewModelUIOption == 'ckpt'}
|
||||
onClick={() => setAddModelTab('add')}
|
||||
isChecked={addModelTab == 'add'}
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.addCheckpointModel')}
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
|
||||
isChecked={addNewModelUIOption == 'diffusers'}
|
||||
onClick={() => setAddModelTab('scan')}
|
||||
isChecked={addModelTab == 'scan'}
|
||||
size="sm"
|
||||
>
|
||||
{t('modelManager.addDiffuserModel')}
|
||||
{t('modelManager.scanForModels')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</ButtonGroup>
|
||||
|
||||
<Divider />
|
||||
|
||||
{addNewModelUIOption == 'ckpt' && <AddCheckpointModel />}
|
||||
{addNewModelUIOption == 'diffusers' && <AddDiffusersModel />}
|
||||
{addModelTab == 'add' && <AddModels />}
|
||||
{addModelTab == 'scan' && <ScanModels />}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -1,259 +0,0 @@
|
||||
import {
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Text,
|
||||
VStack,
|
||||
} from '@chakra-ui/react';
|
||||
import { InvokeDiffusersModelConfigProps } from 'app/types/invokeai';
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { setAddNewModelUIOption } from 'features/ui/store/uiSlice';
|
||||
import { Field, Formik } from 'formik';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { RootState } from 'app/store/store';
|
||||
import IAIForm from 'common/components/IAIForm';
|
||||
import { IAIFormItemWrapper } from 'common/components/IAIForms/IAIFormItemWrapper';
|
||||
|
||||
export default function AddDiffusersModel() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
function hasWhiteSpace(s: string) {
|
||||
return /\s/.test(s);
|
||||
}
|
||||
|
||||
function baseValidation(value: string) {
|
||||
let error;
|
||||
if (hasWhiteSpace(value)) error = t('modelManager.cannotUseSpaces');
|
||||
return error;
|
||||
}
|
||||
|
||||
const addModelFormValues: InvokeDiffusersModelConfigProps = {
|
||||
name: '',
|
||||
description: '',
|
||||
repo_id: '',
|
||||
path: '',
|
||||
format: 'diffusers',
|
||||
default: false,
|
||||
vae: {
|
||||
repo_id: '',
|
||||
path: '',
|
||||
},
|
||||
};
|
||||
|
||||
const addModelFormSubmitHandler = (
|
||||
values: InvokeDiffusersModelConfigProps
|
||||
) => {
|
||||
const diffusersModelToAdd = values;
|
||||
|
||||
if (values.path === '') delete diffusersModelToAdd.path;
|
||||
if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
|
||||
if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
|
||||
if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;
|
||||
|
||||
dispatch(addNewModel(diffusersModelToAdd));
|
||||
dispatch(setAddNewModelUIOption(null));
|
||||
};
|
||||
|
||||
return (
|
||||
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
|
||||
<Formik
|
||||
initialValues={addModelFormValues}
|
||||
onSubmit={addModelFormSubmitHandler}
|
||||
>
|
||||
{({ handleSubmit, errors, touched }) => (
|
||||
<IAIForm onSubmit={handleSubmit} w="full">
|
||||
<VStack rowGap={2}>
|
||||
<IAIFormItemWrapper>
|
||||
{/* Name */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.name && touched.name}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="name" fontSize="sm">
|
||||
{t('modelManager.name')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="name"
|
||||
name="name"
|
||||
type="text"
|
||||
validate={baseValidation}
|
||||
isRequired
|
||||
/>
|
||||
{!!errors.name && touched.name ? (
|
||||
<FormErrorMessage>{errors.name}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.nameValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIFormItemWrapper>
|
||||
{/* Description */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.description && touched.description}
|
||||
isRequired
|
||||
>
|
||||
<FormLabel htmlFor="description" fontSize="sm">
|
||||
{t('modelManager.description')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="description"
|
||||
name="description"
|
||||
type="text"
|
||||
isRequired
|
||||
/>
|
||||
{!!errors.description && touched.description ? (
|
||||
<FormErrorMessage>{errors.description}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.descriptionValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIFormItemWrapper>
|
||||
<Text fontWeight="bold" fontSize="sm">
|
||||
{t('modelManager.formMessageDiffusersModelLocation')}
|
||||
</Text>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
fontStyle: 'italic',
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.formMessageDiffusersModelLocationDesc')}
|
||||
</Text>
|
||||
|
||||
{/* Path */}
|
||||
<FormControl isInvalid={!!errors.path && touched.path}>
|
||||
<FormLabel htmlFor="path" fontSize="sm">
|
||||
{t('modelManager.modelLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field as={IAIInput} id="path" name="path" type="text" />
|
||||
{!!errors.path && touched.path ? (
|
||||
<FormErrorMessage>{errors.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.modelLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* Repo ID */}
|
||||
<FormControl isInvalid={!!errors.repo_id && touched.repo_id}>
|
||||
<FormLabel htmlFor="repo_id" fontSize="sm">
|
||||
{t('modelManager.repo_id')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="repo_id"
|
||||
name="repo_id"
|
||||
type="text"
|
||||
/>
|
||||
{!!errors.repo_id && touched.repo_id ? (
|
||||
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.repoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIFormItemWrapper>
|
||||
{/* VAE Path */}
|
||||
<Text fontWeight="bold">
|
||||
{t('modelManager.formMessageDiffusersVAELocation')}
|
||||
</Text>
|
||||
<Text
|
||||
sx={{
|
||||
fontSize: 'sm',
|
||||
fontStyle: 'italic',
|
||||
}}
|
||||
variant="subtext"
|
||||
>
|
||||
{t('modelManager.formMessageDiffusersVAELocationDesc')}
|
||||
</Text>
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.path && touched.vae?.path}
|
||||
>
|
||||
<FormLabel htmlFor="vae.path" fontSize="sm">
|
||||
{t('modelManager.vaeLocation')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.path"
|
||||
name="vae.path"
|
||||
type="text"
|
||||
/>
|
||||
{!!errors.vae?.path && touched.vae?.path ? (
|
||||
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeLocationValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
|
||||
{/* VAE Repo ID */}
|
||||
<FormControl
|
||||
isInvalid={!!errors.vae?.repo_id && touched.vae?.repo_id}
|
||||
>
|
||||
<FormLabel htmlFor="vae.repo_id" fontSize="sm">
|
||||
{t('modelManager.vaeRepoID')}
|
||||
</FormLabel>
|
||||
<VStack alignItems="start">
|
||||
<Field
|
||||
as={IAIInput}
|
||||
id="vae.repo_id"
|
||||
name="vae.repo_id"
|
||||
type="text"
|
||||
/>
|
||||
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
|
||||
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>
|
||||
) : (
|
||||
<FormHelperText margin={0}>
|
||||
{t('modelManager.vaeRepoIDValidationMsg')}
|
||||
</FormHelperText>
|
||||
)}
|
||||
</VStack>
|
||||
</FormControl>
|
||||
</IAIFormItemWrapper>
|
||||
|
||||
<IAIButton type="submit" isLoading={isProcessing}>
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</VStack>
|
||||
</IAIForm>
|
||||
)}
|
||||
</Formik>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -0,0 +1,110 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
// import { addNewModel } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
import { RootState } from 'app/store/store';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
const predictionSelectData: SelectItem[] = [
|
||||
{ label: 'None', value: 'none' },
|
||||
{ label: 'v_prediction', value: 'v_prediction' },
|
||||
{ label: 'epsilon', value: 'epsilon' },
|
||||
{ label: 'sample', value: 'sample' },
|
||||
];
|
||||
|
||||
type ExtendedImportModelConfig = {
|
||||
location: string;
|
||||
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
|
||||
};
|
||||
|
||||
export default function AddModels() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isProcessing = useAppSelector(
|
||||
(state: RootState) => state.system.isProcessing
|
||||
);
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
|
||||
const addModelForm = useForm<ExtendedImportModelConfig>({
|
||||
initialValues: {
|
||||
location: '',
|
||||
prediction_type: undefined,
|
||||
},
|
||||
});
|
||||
|
||||
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
|
||||
const importModelResponseBody = {
|
||||
location: values.location,
|
||||
prediction_type:
|
||||
values.prediction_type === 'none' ? undefined : values.prediction_type,
|
||||
};
|
||||
|
||||
importMainModel({ body: importModelResponseBody })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: 'Model Added',
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
addModelForm.reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
console.log(error);
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))}>
|
||||
<Flex
|
||||
flexDirection="column"
|
||||
overflow="scroll"
|
||||
maxHeight={window.innerHeight - 270}
|
||||
width="100%"
|
||||
gap={4}
|
||||
>
|
||||
<IAIMantineTextInput
|
||||
label="Model Location"
|
||||
w="100%"
|
||||
{...addModelForm.getInputProps('location')}
|
||||
/>
|
||||
<IAIMantineSelect
|
||||
label="Prediction Type (used for Stable Diffusion 2.x Models)"
|
||||
data={predictionSelectData}
|
||||
defaultValue="none"
|
||||
{...addModelForm.getInputProps('prediction_type')}
|
||||
/>
|
||||
<IAIButton
|
||||
type="submit"
|
||||
isLoading={isLoading}
|
||||
isDisabled={isLoading || isProcessing}
|
||||
>
|
||||
{t('modelManager.addModel')}
|
||||
</IAIButton>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
}
|
@ -22,7 +22,11 @@ export default function FoundModelsList() {
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex sx={{ flexDirection: 'column' }}>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
}}
|
||||
>
|
||||
{foundModels.map((model) => (
|
||||
<Flex key={model}>{model}</Flex>
|
||||
))}
|
||||
|
@ -33,7 +33,7 @@ import SearchModels from './SearchModels';
|
||||
const MIN_MODEL_SIZE = 64;
|
||||
const MAX_MODEL_SIZE = 2048;
|
||||
|
||||
export default function AddCheckpointModel() {
|
||||
export default function ScanModels() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
@ -11,7 +11,9 @@ export default function SearchModels() {
|
||||
return (
|
||||
<Flex flexDirection="column" w="100%">
|
||||
<SearchFolderForm />
|
||||
<Flex sx={{ maxHeight: window.innerHeight - 400, overflow: 'scroll' }}>
|
||||
<FoundModelsList />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -1,7 +1,5 @@
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
|
||||
export type AddNewModelType = 'ckpt' | 'diffusers' | null;
|
||||
|
||||
export type Coordinates = {
|
||||
x: number;
|
||||
y: number;
|
||||
@ -22,7 +20,6 @@ export interface UIState {
|
||||
shouldUseCanvasBetaLayout: boolean;
|
||||
shouldShowExistingModelsInSearch: boolean;
|
||||
shouldUseSliders: boolean;
|
||||
addNewModelUIOption: AddNewModelType;
|
||||
shouldHidePreview: boolean;
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
ControlNetModelConfig,
|
||||
ConvertModelConfig,
|
||||
DiffusersModelConfig,
|
||||
ImportModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
@ -78,6 +79,13 @@ type MergeMainModelArg = {
|
||||
type MergeMainModelResponse =
|
||||
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportMainModelArg = {
|
||||
body: ImportModelConfig;
|
||||
};
|
||||
|
||||
type ImportMainModelResponse =
|
||||
paths['/api/v1/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
type SearchFolderResponse =
|
||||
paths['/api/v1/models/search']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
@ -168,6 +176,19 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
}),
|
||||
importMainModels: build.mutation<
|
||||
ImportMainModelResponse,
|
||||
ImportMainModelArg
|
||||
>({
|
||||
query: ({ body }) => {
|
||||
return {
|
||||
url: `models/import`,
|
||||
method: 'POST',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
}),
|
||||
deleteMainModels: build.mutation<
|
||||
DeleteMainModelResponse,
|
||||
DeleteMainModelArg
|
||||
@ -356,6 +377,7 @@ export const {
|
||||
useGetVaeModelsQuery,
|
||||
useUpdateMainModelsMutation,
|
||||
useDeleteMainModelsMutation,
|
||||
useImportMainModelsMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useGetModelsInFolderQuery,
|
||||
|
@ -60,7 +60,7 @@ export type AnyModelConfig =
|
||||
|
||||
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
|
||||
export type SearchFolderConfig = components['schemas'];
|
||||
export type ImportModelConfig = components['schemas']['Body_import_model'];
|
||||
|
||||
// Graphs
|
||||
export type Graph = components['schemas']['Graph'];
|
||||
|
Loading…
Reference in New Issue
Block a user