feat(ui): update UI to use new model config backend

- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
This commit is contained in:
psychedelicious 2024-03-05 19:04:13 +11:00
parent 48119d9010
commit 99407c899f
30 changed files with 993 additions and 1824 deletions

View File

@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return; return;
} }
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap(); const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
if (!metadata || !metadata.default_settings) { if (!modelConfig || !modelConfig.default_settings) {
return; return;
} }
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings; const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) { if (vae) {
// we store this as "default" within default settings // we store this as "default" within default settings

View File

@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data; const { bytes, total_bytes, id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.bytes = bytes; modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data; const { id } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'completed'; modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft; return draft;
}) })
); );
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }])); dispatch(api.util.invalidateTags(['Model']));
}, },
}); });
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data; const { id, error, error_type } = action.payload.data;
dispatch( dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => { modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id); const modelImport = draft.find((m) => m.id === id);
if (modelImport) { if (modelImport) {
modelImport.status = 'error'; modelImport.status = 'error';

View File

@ -1,228 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string; location: string;
}; };
export const SimpleImport = () => { export const InstallModelForm = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation(); const [installModel, { isLoading }] = useInstallModelMutation();

View File

@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models'; import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { ImportQueueItem } from './ImportQueueItem'; import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ImportQueue = () => { export const ModelInstallQueue = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetModelImportsQuery(); const { data } = useListModelInstallsQuery();
const [pruneModelImports] = usePruneModelImportsMutation(); const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
const pruneQueue = useCallback(() => { const pruneCompletedModelInstalls = useCallback(() => {
pruneModelImports() _pruneCompletedModelInstalls()
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -41,7 +41,7 @@ export const ImportQueue = () => {
); );
} }
}); });
}, [pruneModelImports, dispatch]); }, [_pruneCompletedModelInstalls, dispatch]);
const pruneAvailable = useMemo(() => { const pruneAvailable = useMemo(() => {
return data?.some( return data?.some(
@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full"> <Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text> <Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}> <Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')} {t('modelManager.prune')}
</Button> </Button>
</Flex> </Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full"> <Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent> <ScrollableContent>
<Flex flexDir="column-reverse" gap="2"> <Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)} {data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex> </Flex>
</ScrollableContent> </ScrollableContent>
</Box> </Box>

View File

@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = { const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' }, waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' }, running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' }, completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' }, error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' }, cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
}; };
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => { const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation(); const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) { if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>; return null;
} }
return ( return (
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip> </Tooltip>
); );
}; };
export default memo(ImportQueueBadge); export default memo(ModelInstallQueueBadge);

View File

@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models'; import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types'; import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge'; import ModelInstallQueueBadge from './ModelInstallQueueBadge';
type ModelListItemProps = { type ModelListItemProps = {
model: ModelInstallJob; installJob: ModelInstallJob;
}; };
const formatBytes = (bytes: number) => { const formatBytes = (bytes: number) => {
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`; return `${bytes.toFixed(2)} ${units[i]}`;
}; };
export const ImportQueueItem = (props: ModelListItemProps) => { export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { model } = props; const { installJob } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation(); const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => { const source = useMemo(() => {
if (model.source.type === 'hf') { if (installJob.source.type === 'hf') {
return model.source as HFModelSource; return installJob.source as HFModelSource;
} else if (model.source.type === 'local') { } else if (installJob.source.type === 'local') {
return model.source as LocalModelSource; return installJob.source as LocalModelSource;
} else if (model.source.type === 'url') { } else if (installJob.source.type === 'url') {
return model.source as URLModelSource; return installJob.source as URLModelSource;
} else { } else {
return model.source as LocalModelSource; return installJob.source as LocalModelSource;
} }
}, [model.source]); }, [installJob.source]);
const handleDeleteModelImport = useCallback(() => { const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id) deleteImportModel(installJob.id)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
dispatch( dispatch(
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
); );
} }
}); });
}, [deleteImportModel, model, dispatch]); }, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => { const modelName = useMemo(() => {
switch (source.type) { switch (source.type) {
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]); }, [source]);
const progressValue = useMemo(() => { const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) { if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
if (installJob.total_bytes === 0) {
return 0; return 0;
} }
return (model.bytes / model.total_bytes) * 100; return (installJob.bytes / installJob.total_bytes) * 100;
}, [model.bytes, model.total_bytes]); }, [installJob.bytes, installJob.total_bytes]);
const progressString = useMemo(() => { const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) { if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return ''; return '';
} }
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`; return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]); }, [installJob.bytes, installJob.total_bytes, installJob.status]);
return ( return (
<Flex gap="2" w="full" alignItems="center"> <Flex gap="2" w="full" alignItems="center">
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}> <Flex flexDir="column" flex={1}>
<Tooltip label={progressString}> <Tooltip label={progressString}>
<Progress <Progress
value={progressValue} value={progressValue ?? 0}
isIndeterminate={progressValue === undefined} isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')} aria-label={t('accessibility.invokeProgressBar')}
h={2} h={2}
/> />
</Tooltip> </Tooltip>
</Flex> </Flex>
<Box minW="100px" textAlign="center"> <Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} /> <ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box> </Box>
<Box minW="20px"> <Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && ( {(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton <IconButton
isRound={true} isRound={true}
size="xs" size="xs"

View File

@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react'; import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react'; import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models'; import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults'; import { ScanModelsResults } from './ScanFolderResults';
export const ScanModelsForm = () => { export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState(''); const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState(''); const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation(); const { t } = useTranslation();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery(); const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const handleSubmitScan = useCallback(async () => { const scanFolder = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => { _scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) { if (error) {
setErrorMessage(error.data.detail); setErrorMessage(error.data.detail);
} }
}); });
}, [_scanModels, scanPath]); }, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => { const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value); setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} /> <Input value={scanPath} onChange={handleSetScanPath} />
</Flex> </Flex>
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}> <Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')} {t('modelManager.scanFolder')}
</Button> </Button>
</Flex> </Flex>

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi'; import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models'; import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem'; import { ScanModelResultItem } from './ScanFolderResultItem';
type ScanModelResultsProps = { type ScanModelResultsProps = {
results: ScanFolderResponse; results: ScanFolderResponse;

View File

@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library'; import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { AdvancedImport } from './AddModelPanel/AdvancedImport'; import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue'; import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm'; import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
export const ImportModels = () => { export const InstallModels = () => {
const { t } = useTranslation(); const { t } = useTranslation();
return ( return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}> <Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%"> <Tabs variant="collapse" height="100%">
<TabList> <TabList>
<Tab>{t('common.simple')}</Tab> <Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab> <Tab>{t('modelManager.scan')}</Tab>
</TabList> </TabList>
<TabPanels p={3} height="100%"> <TabPanels p={3} height="100%">
<TabPanel> <TabPanel>
<SimpleImport /> <InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
</TabPanel> </TabPanel>
<TabPanel height="100%"> <TabPanel height="100%">
<ScanModelsForm /> <ScanModelsForm />
@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs> </Tabs>
</Box> </Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%"> <Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue /> <ModelInstallQueue />
</Box> </Box>
</Flex> </Flex>
); );

View File

@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5'; import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = { const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main', main: 'Main',
lora: 'LoRA', lora: 'LoRA',
embedding: 'Textual Inversion', embedding: 'Textual Inversion',

View File

@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library'; import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ImportModels } from './ImportModels'; import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model'; import { Model } from './ModelPanel/Model';
export const ModelPane = () => { export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return ( return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full"> <Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />} {selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box> </Box>
); );
}; };

View File

@ -5,7 +5,7 @@ import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice'; import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es'; import { isNil } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm'; import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
@ -24,7 +24,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
export const DefaultSettings = () => { export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } = const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector); useAppSelector(initialStatesSelector);

View File

@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5'; import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models'; import { useUpdateModelMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier'; import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale'; import { DefaultCfgScale } from './DefaultCfgScale';
@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
const { t } = useTranslation(); const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation(); const [updateModel, { isLoading }] = useUpdateModelMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({ const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults, defaultValues: defaultSettingsDefaults,
@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null, scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
}; };
editModelMetadata({ updateModel({
key: selectedModelKey, key: selectedModelKey,
body: { default_settings: body }, body: { default_settings: body },
}) })
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
} }
}); });
}, },
[selectedModelKey, dispatch, editModelMetadata, t] [selectedModelKey, dispatch, updateModel, t]
); );
return ( return (

View File

@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form'; import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types'; import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [ const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] }, { value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] }, { value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
]; ];
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => { type Props = {
const { field } = useController(props); control: Control<UpdateModelArg['body']>;
};
const BaseModelSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'base' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>( const onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'true', label: 'True' },
{ value: 'false', label: 'False' },
];
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(BooleanSelect);

View File

@ -1,47 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
{ value: 'diffusers', label: 'Diffusers' },
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelFormatSelect);

View File

@ -1,32 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
] as const;
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form'; import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types'; import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [ const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' }, { value: 'normal', label: 'Normal' },
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
{ value: 'depth', label: 'Depth' }, { value: 'depth', label: 'Depth' },
]; ];
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => { type Props = {
const { field } = useController(props); control: Control<UpdateModelArg['body']>;
};
const ModelVariantSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'variant' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>( const onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library'; import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form'; import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types'; import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [ const options: ComboboxOption[] = [
{ value: 'none', label: '-' }, { value: 'none', label: '-' },
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
{ value: 'sample', label: 'sample' }, { value: 'sample', label: 'sample' },
]; ];
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => { type Props = {
const { field } = useController(props); control: Control<UpdateModelArg['body']>;
};
const PredictionTypeSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'prediction_type' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]); const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>( const onChange = useCallback<ComboboxOnChange>(
(v) => { (v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'fp16', label: 'fp16' },
{ value: 'fp32', label: 'fp32' },
];
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(RepoVariantSelect);

View File

@ -2,16 +2,16 @@ import { Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query'; import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
export const ModelMetadata = () => { export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken); const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
return ( return (
<> <>
<Flex flexDir="column" height="full" gap="3"> <Flex flexDir="column" height="full" gap="3">
<DataViewer label="metadata" data={metadata || {}} /> <DataViewer label="metadata" data={data?.source_api_response || {}} />
</Flex> </Flex>
</> </>
); );

View File

@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models'; import { useConvertModelMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types'; import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps { interface ModelConvertProps {
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
const { model } = props; const { model } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation(); const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
const modelConvertHandler = useCallback(() => { const modelConvertHandler = useCallback(() => {

View File

@ -1,5 +1,6 @@
import { import {
Button, Button,
Checkbox,
Flex, Flex,
FormControl, FormControl,
FormErrorMessage, FormErrorMessage,
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models'; import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models'; import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';
export const ModelEdit = () => { export const ModelEdit = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation(); const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const { t } = useTranslation(); const { t } = useTranslation();
// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;
// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }
// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);
const { const {
register, register,
handleSubmit, handleSubmit,
control, control,
formState: { errors }, formState: { errors },
reset, reset,
watch,
} = useForm<UpdateModelArg['body']>({ } = useForm<UpdateModelArg['body']>({
defaultValues: { defaultValues: {
...data, ...data,
@ -86,10 +48,7 @@ export const ModelEdit = () => {
mode: 'onChange', mode: 'onChange',
}); });
const watchedModelType = watch('type'); const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
const watchedModelFormat = watch('format');
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => { (values) => {
if (!data?.key) { if (!data?.key) {
return; return;
@ -143,33 +102,31 @@ export const ModelEdit = () => {
return ( return (
<Flex flexDir="column" h="full"> <Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}> <form onSubmit={handleSubmit(onSubmit)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}> <Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center"> <FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel> <FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
<Input <Input
{...register('name', { {...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters', validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})} })}
size="lg" size="lg"
/> />
<Flex gap={2}> {errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
<Button size="sm" onClick={handleClickCancel}> </FormControl>
{t('common.cancel')} <Button size="sm" onClick={handleClickCancel}>
</Button> {t('common.cancel')}
<Button </Button>
size="sm" <Button
colorScheme="invokeYellow" size="sm"
onClick={handleSubmit(onSubmit)} colorScheme="invokeYellow"
isLoading={isSubmitting} onClick={handleSubmit(onSubmit)}
isDisabled={Boolean(Object.keys(errors).length)} isLoading={isSubmitting}
> isDisabled={Boolean(Object.keys(errors).length)}
{t('common.save')} >
</Button> {t('common.save')}
</Flex> </Button>
</Flex> </Flex>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<Flex flexDir="column" gap={3} mt="4"> <Flex flexDir="column" gap={3} mt="4">
<Flex> <Flex>
@ -184,76 +141,22 @@ export const ModelEdit = () => {
<Flex gap={4}> <Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel> <FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" /> <BaseModelSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl> </FormControl>
</Flex> </Flex>
<Flex gap={4}> {data.type === 'main' && data.format === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}> <Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel> <FormLabel>{t('modelManager.variant')}</FormLabel>
<Input {...register('image_encoder_model_id')} /> <ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl> </FormControl>
</Flex> </Flex>
)} )}

View File

@ -91,26 +91,19 @@ export const ModelView = () => {
<ModelAttrView label={t('modelManager.path')} value={modelData.path} /> <ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex> </Flex>
{modelData.type === 'main' && ( {modelData.type === 'main' && (
<> <Flex gap={2}>
<Flex gap={2}> {modelData.format === 'diffusers' && modelData.repo_variant && (
{modelData.format === 'diffusers' && ( <ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} /> )}
)} {modelData.format === 'checkpoint' && (
{modelData.format === 'checkpoint' && ( <>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} /> <ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
)} <ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} /> <ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex> </>
<Flex gap={2}> )}
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} /> </Flex>
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
)} )}
{modelData.type === 'ip_adapter' && ( {modelData.type === 'ip_adapter' && (
<Flex gap={2}> <Flex gap={2}>

View File

@ -1,7 +1,6 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'; import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit'; import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector'; import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { JSONObject } from 'common/types';
import queryString from 'query-string'; import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema'; import type { operations, paths } from 'services/api/schema';
import type { import type {
@ -24,49 +23,33 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
}; };
type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>; type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type DeleteMainModelArg = { type DeleteModelArg = {
key: string; key: string;
}; };
type DeleteModelResponse = void;
type DeleteMainModelResponse = void;
type ConvertMainModelResponse = type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json']; paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type InstallModelArg = { type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source']; source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
config?: JSONObject;
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
}; };
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json']; type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse = type ListModelInstallsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
type DeleteImportModelsResponse = type CancelModelInstallResponse =
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json']; paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
type PruneModelImportsResponse = type PruneCompletedModelInstallsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json']; paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
export type ScanFolderResponse = export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json']; paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
@ -146,31 +129,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({ updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => {
return buildModelsUrl(`i/${key}/metadata`);
},
providesTags: ['Model'],
}),
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ key, body }) => { query: ({ key, body }) => {
return { return {
url: buildModelsUrl(`i/${key}`), url: buildModelsUrl(`i/${key}`),
@ -180,28 +139,17 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => { query: ({ source }) => {
return { return {
url: buildModelsUrl('install'), url: buildModelsUrl('install'),
params: { source, access_token }, params: { source },
method: 'POST', method: 'POST',
body: config,
}; };
}, },
invalidatesTags: ['Model', 'ModelImports'], invalidatesTags: ['Model', 'ModelInstalls'],
}), }),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({ deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
query: ({ key }) => { query: ({ key }) => {
return { return {
url: buildModelsUrl(`i/${key}`), url: buildModelsUrl(`i/${key}`),
@ -210,7 +158,7 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
convertMainModels: build.mutation<ConvertMainModelResponse, string>({ convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => { query: (key) => {
return { return {
url: buildModelsUrl(`convert/${key}`), url: buildModelsUrl(`convert/${key}`),
@ -253,6 +201,57 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
};
},
providesTags: ['ModelInstalls'],
}),
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`install/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }), query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'), providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
@ -313,40 +312,6 @@ export const modelsApi = api.injectEndpoints({
}); });
}, },
}), }),
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
getModelImports: build.query<ListImportModelsResponse, void>({
query: () => {
return {
url: buildModelsUrl(`import`),
};
},
providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`import/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelImports'],
}),
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
query: () => {
return {
url: buildModelsUrl('import'),
method: 'PATCH',
};
},
invalidatesTags: ['ModelImports'],
}),
}), }),
}); });
@ -360,16 +325,14 @@ export const {
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useDeleteModelsMutation, useDeleteModelsMutation,
useUpdateModelsMutation, useUpdateModelMutation,
useInstallModelMutation, useInstallModelMutation,
useConvertMainModelsMutation, useConvertModelMutation,
useSyncModelsMutation, useSyncModelsMutation,
useLazyScanModelsQuery, useLazyScanFolderQuery,
useGetModelImportsQuery, useListModelInstallsQuery,
useGetModelMetadataQuery, useCancelModelInstallMutation,
useDeleteModelImportMutation, usePruneCompletedModelInstallsMutation,
usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
} = modelsApi; } = modelsApi;
const upsertModelConfigs = ( const upsertModelConfigs = (

View File

@ -28,7 +28,7 @@ export const tagTypes = [
'InvocationCacheStatus', 'InvocationCacheStatus',
'Model', 'Model',
'ModelConfig', 'ModelConfig',
'ModelImports', 'ModelInstalls',
'T2IAdapterModel', 'T2IAdapterModel',
'MainModel', 'MainModel',
'VaeModel', 'VaeModel',

File diff suppressed because one or more lines are too long

View File

@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
// Model Configs // Model Configs
// TODO(MM2): Can we make key required in the pydantic model? // TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRAConfig']; export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE // TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig']; export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig']; export type IPAdapterModelConfig = S['IPAdapterConfig'];
// TODO(MM2): Can we rename this to T2IAdapterConfig export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type T2IAdapterModelConfig = S['T2IConfig']; export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
export type TextualInversionModelConfig = S['TextualInversionConfig'];
export type DiffusersModelConfig = S['MainDiffusersConfig']; export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig']; export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];