feat(ui): update UI to use new model config backend

- Update all queries
- Remove Advanced Add
- Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type)
- Update model tags so the list refreshes when a model installs
- Rename some queries, components, variables, types to match backend
- Fix divide-by-zero in install queue
This commit is contained in:
psychedelicious 2024-03-05 19:04:13 +11:00
parent 48119d9010
commit 99407c899f
30 changed files with 993 additions and 1824 deletions

View File

@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
return;
}
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
if (!metadata || !metadata.default_settings) {
if (!modelConfig || !modelConfig.default_settings) {
return;
}
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
if (vae) {
// we store this as "default" within default settings

View File

@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { bytes, total_bytes, id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.bytes = bytes;
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'completed';
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
return draft;
})
);
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
dispatch(api.util.invalidateTags(['Model']));
},
});
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
const { id, error, error_type } = action.payload.data;
dispatch(
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
const modelImport = draft.find((m) => m.id === id);
if (modelImport) {
modelImport.status = 'error';

View File

@ -1,228 +0,0 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { isNil, omitBy } from 'lodash-es';
import { useCallback, useEffect } from 'react';
import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
export const AdvancedImport = () => {
const dispatch = useAppDispatch();
const [installModel] = useInstallModelMutation();
const { t } = useTranslation();
const {
register,
handleSubmit,
control,
formState: { errors },
setValue,
resetField,
reset,
watch,
} = useForm<AnyModelConfig>({
defaultValues: {
name: '',
base: 'sd-1',
type: 'main',
path: '',
description: '',
format: 'diffusers',
vae: '',
variant: 'normal',
},
mode: 'onChange',
});
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
(values) => {
installModel({
source: values.path,
config: omitBy(values, isNil),
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelAdded', {
modelName: values.name,
}),
status: 'success',
})
)
);
reset();
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.modelAddFailed'),
status: 'error',
})
)
);
}
});
},
[installModel, dispatch, t, reset]
);
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
useEffect(() => {
if (watchedModelType === 'main') {
setValue('format', 'diffusers');
setValue('repo_variant', '');
setValue('variant', 'normal');
}
if (watchedModelType === 'lora') {
setValue('format', 'lycoris');
} else if (watchedModelType === 'embedding') {
setValue('format', 'embedding_file');
} else if (watchedModelType === 'ip_adapter') {
setValue('format', 'invokeai');
} else {
setValue('format', 'diffusers');
}
resetField('upcast_attention');
resetField('ztsnr_training');
resetField('vae');
resetField('config');
resetField('prediction_type');
resetField('image_encoder_model_id');
}, [watchedModelType, resetField, setValue]);
return (
<ScrollableContent>
<form onSubmit={handleSubmit(onSubmit)}>
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
<Flex alignItems="flex-end" gap="4">
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
</FormControl>
<Text px="2" fontSize="xs" textAlign="center">
{t('modelManager.advancedImportInfo')}
</Text>
</Flex>
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
<FormControl isInvalid={Boolean(errors.name)}>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.name')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
})}
/>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
<Flex>
<FormControl>
<Flex direction="column" width="full">
<FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea size="sm" {...register('description')} />
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</Flex>
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
</FormControl>
</Flex>
)}
<Button mt={2} type="submit">
{t('modelManager.addModel')}
</Button>
</Flex>
</Flex>
</form>
</ScrollableContent>
);
};

View File

@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
location: string;
};
export const SimpleImport = () => {
export const InstallModelForm = () => {
const dispatch = useAppDispatch();
const [installModel, { isLoading }] = useInstallModelMutation();

View File

@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { useCallback, useMemo } from 'react';
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
import { ImportQueueItem } from './ImportQueueItem';
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
export const ImportQueue = () => {
export const ModelInstallQueue = () => {
const dispatch = useAppDispatch();
const { data } = useGetModelImportsQuery();
const { data } = useListModelInstallsQuery();
const [pruneModelImports] = usePruneModelImportsMutation();
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
const pruneQueue = useCallback(() => {
pruneModelImports()
const pruneCompletedModelInstalls = useCallback(() => {
_pruneCompletedModelInstalls()
.unwrap()
.then((_) => {
dispatch(
@ -41,7 +41,7 @@ export const ImportQueue = () => {
);
}
});
}, [pruneModelImports, dispatch]);
}, [_pruneCompletedModelInstalls, dispatch]);
const pruneAvailable = useMemo(() => {
return data?.some(
@ -53,14 +53,19 @@ export const ImportQueue = () => {
<Flex flexDir="column" p={3} h="full">
<Flex justifyContent="space-between" alignItems="center">
<Text>{t('modelManager.importQueue')}</Text>
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
<Button
size="sm"
isDisabled={!pruneAvailable}
onClick={pruneCompletedModelInstalls}
tooltip={t('modelManager.pruneTooltip')}
>
{t('modelManager.prune')}
</Button>
</Flex>
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
<ScrollableContent>
<Flex flexDir="column-reverse" gap="2">
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
</Flex>
</ScrollableContent>
</Box>

View File

@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
const STATUSES = {
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
error: { colorScheme: 'red', translationKey: 'queue.failed' },
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
};
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
const ModelInstallQueueBadge = ({
status,
errorReason,
}: {
status?: ModelInstallStatus;
errorReason?: string | null;
}) => {
const { t } = useTranslation();
if (!status || !Object.keys(STATUSES).includes(status)) {
return <></>;
return null;
}
return (
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
</Tooltip>
);
};
export default memo(ImportQueueBadge);
export default memo(ModelInstallQueueBadge);

View File

@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { isNil } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import { PiXBold } from 'react-icons/pi';
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
import ImportQueueBadge from './ImportQueueBadge';
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
type ModelListItemProps = {
model: ModelInstallJob;
installJob: ModelInstallJob;
};
const formatBytes = (bytes: number) => {
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
return `${bytes.toFixed(2)} ${units[i]}`;
};
export const ImportQueueItem = (props: ModelListItemProps) => {
const { model } = props;
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
const { installJob } = props;
const dispatch = useAppDispatch();
const [deleteImportModel] = useDeleteModelImportMutation();
const [deleteImportModel] = useCancelModelInstallMutation();
const source = useMemo(() => {
if (model.source.type === 'hf') {
return model.source as HFModelSource;
} else if (model.source.type === 'local') {
return model.source as LocalModelSource;
} else if (model.source.type === 'url') {
return model.source as URLModelSource;
if (installJob.source.type === 'hf') {
return installJob.source as HFModelSource;
} else if (installJob.source.type === 'local') {
return installJob.source as LocalModelSource;
} else if (installJob.source.type === 'url') {
return installJob.source as URLModelSource;
} else {
return model.source as LocalModelSource;
return installJob.source as LocalModelSource;
}
}, [model.source]);
}, [installJob.source]);
const handleDeleteModelImport = useCallback(() => {
deleteImportModel(model.id)
deleteImportModel(installJob.id)
.unwrap()
.then((_) => {
dispatch(
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
);
}
});
}, [deleteImportModel, model, dispatch]);
}, [deleteImportModel, installJob, dispatch]);
const modelName = useMemo(() => {
switch (source.type) {
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
}, [source]);
const progressValue = useMemo(() => {
if (model.bytes === undefined || model.total_bytes === undefined) {
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
return null;
}
if (installJob.total_bytes === 0) {
return 0;
}
return (model.bytes / model.total_bytes) * 100;
}, [model.bytes, model.total_bytes]);
return (installJob.bytes / installJob.total_bytes) * 100;
}, [installJob.bytes, installJob.total_bytes]);
const progressString = useMemo(() => {
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
return '';
}
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
}, [model.bytes, model.total_bytes, model.status]);
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
return (
<Flex gap="2" w="full" alignItems="center">
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
<Flex flexDir="column" flex={1}>
<Tooltip label={progressString}>
<Progress
value={progressValue}
isIndeterminate={progressValue === undefined}
value={progressValue ?? 0}
isIndeterminate={progressValue === null}
aria-label={t('accessibility.invokeProgressBar')}
h={2}
/>
</Tooltip>
</Flex>
<Box minW="100px" textAlign="center">
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
</Box>
<Box minW="20px">
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
{(installJob.status === 'downloading' ||
installJob.status === 'waiting' ||
installJob.status === 'running') && (
<IconButton
isRound={true}
size="xs"

View File

@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
import { ScanModelsResults } from './ScanModelsResults';
import { ScanModelsResults } from './ScanFolderResults';
export const ScanModelsForm = () => {
const [scanPath, setScanPath] = useState('');
const [errorMessage, setErrorMessage] = useState('');
const { t } = useTranslation();
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
const handleSubmitScan = useCallback(async () => {
_scanModels({ scan_path: scanPath }).catch((error) => {
const scanFolder = useCallback(async () => {
_scanFolder({ scan_path: scanPath }).catch((error) => {
if (error) {
setErrorMessage(error.data.detail);
}
});
}, [_scanModels, scanPath]);
}, [_scanFolder, scanPath]);
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
setScanPath(e.target.value);
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
<Input value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
{t('modelManager.scanFolder')}
</Button>
</Flex>

View File

@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
import { ScanModelResultItem } from './ScanModelResultItem';
import { ScanModelResultItem } from './ScanFolderResultItem';
type ScanModelResultsProps = {
results: ScanFolderResponse;

View File

@ -1,12 +1,11 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
import { SimpleImport } from './AddModelPanel/SimpleImport';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
export const ImportModels = () => {
export const InstallModels = () => {
const { t } = useTranslation();
return (
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
@ -17,15 +16,11 @@ export const ImportModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.advanced')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<SimpleImport />
</TabPanel>
<TabPanel height="100%">
<AdvancedImport />
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
@ -34,7 +29,7 @@ export const ImportModels = () => {
</Tabs>
</Box>
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
<ImportQueue />
<ModelInstallQueue />
</Box>
</Flex>
);

View File

@ -5,7 +5,7 @@ import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoFilter } from 'react-icons/io5';
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
const MODEL_TYPE_LABELS: { [key: string]: string } = {
main: 'Main',
lora: 'LoRA',
embedding: 'Textual Inversion',

View File

@ -1,14 +1,14 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ImportModels } from './ImportModels';
import { InstallModels } from './InstallModels';
import { Model } from './ModelPanel/Model';
export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
</Box>
);
};

View File

@ -5,7 +5,7 @@ import Loading from 'common/components/Loading/Loading';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { isNil } from 'lodash-es';
import { useMemo } from 'react';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
@ -24,7 +24,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
export const DefaultSettings = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
useAppSelector(initialStatesSelector);

View File

@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import { IoPencil } from 'react-icons/io5';
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
import { useUpdateModelMutation } from 'services/api/endpoints/models';
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
import { DefaultCfgScale } from './DefaultCfgScale';
@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
const { t } = useTranslation();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
const [updateModel, { isLoading }] = useUpdateModelMutation();
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
defaultValues: defaultSettingsDefaults,
@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
};
editModelMetadata({
updateModel({
key: selectedModelKey,
body: { default_settings: body },
})
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
}
});
},
[selectedModelKey, dispatch, editModelMetadata, t]
[selectedModelKey, dispatch, updateModel, t]
);
return (

View File

@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const BaseModelSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'base' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'true', label: 'True' },
{ value: 'false', label: 'False' },
];
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(BooleanSelect);

View File

@ -1,47 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
const { field, formState } = useController(props);
const type = useWatch({ control: props.control, name: 'type' });
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
const options: ComboboxOption[] = useMemo(() => {
const modelType = type || formState.defaultValues?.type;
if (modelType === 'lora') {
return [
{ value: 'lycoris', label: 'LyCORIS' },
{ value: 'diffusers', label: 'Diffusers' },
];
} else if (modelType === 'embedding') {
return [
{ value: 'embedding_file', label: 'Embedding File' },
{ value: 'embedding_folder', label: 'Embedding Folder' },
];
} else if (modelType === 'ip_adapter') {
return [{ value: 'invokeai', label: 'invokeai' }];
} else {
return [
{ value: 'diffusers', label: 'Diffusers' },
{ value: 'checkpoint', label: 'Checkpoint' },
];
}
}, [type, formState.defaultValues?.type]);
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelFormatSelect);

View File

@ -1,32 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
] as const;
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(ModelTypeSelect);

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'normal', label: 'Normal' },
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
{ value: 'depth', label: 'Depth' },
];
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const ModelVariantSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'variant' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import type { Control } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
{ value: 'sample', label: 'sample' },
];
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
type Props = {
control: Control<UpdateModelArg['body']>;
};
const PredictionTypeSelect = ({ control }: Props) => {
const { field } = useController({ control, name: 'prediction_type' });
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@ -1,27 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
const options: ComboboxOption[] = [
{ value: 'none', label: '-' },
{ value: 'fp16', label: 'fp16' },
{ value: 'fp32', label: 'fp32' },
];
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
const { field } = useController(props);
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
const onChange = useCallback<ComboboxOnChange>(
(v) => {
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
},
[field]
);
return <Combobox value={value} options={options} onChange={onChange} />;
};
export default typedMemo(RepoVariantSelect);

View File

@ -2,16 +2,16 @@ import { Flex } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
export const ModelMetadata = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
return (
<>
<Flex flexDir="column" height="full" gap="3">
<DataViewer label="metadata" data={metadata || {}} />
<DataViewer label="metadata" data={data?.source_api_response || {}} />
</Flex>
</>
);

View File

@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { useConvertModelMutation } from 'services/api/endpoints/models';
import type { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
const { model } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [convertModel, { isLoading }] = useConvertModelMutation();
const { isOpen, onOpen, onClose } = useDisclosure();
const modelConvertHandler = useCallback(() => {

View File

@ -1,5 +1,6 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormErrorMessage,
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect';
import BooleanSelect from './Fields/BooleanSelect';
import ModelFormatSelect from './Fields/ModelFormatSelect';
import ModelTypeSelect from './Fields/ModelTypeSelect';
import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
import RepoVariantSelect from './Fields/RepoVariantSelect';
export const ModelEdit = () => {
const dispatch = useAppDispatch();
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const { t } = useTranslation();
// const modelData = useMemo(() => {
// if (!data) {
// return null;
// }
// const modelFormat = data.format;
// const modelType = data.type;
// if (modelType === 'main') {
// if (modelFormat === 'diffusers') {
// return data as DiffusersModelConfig;
// } else if (modelFormat === 'checkpoint') {
// return data as CheckpointModelConfig;
// }
// }
// switch (modelType) {
// case 'lora':
// return data as LoRAModelConfig;
// case 'embedding':
// return data as TextualInversionModelConfig;
// case 't2i_adapter':
// return data as T2IAdapterModelConfig;
// case 'ip_adapter':
// return data as IPAdapterModelConfig;
// case 'controlnet':
// return data as ControlNetModelConfig;
// case 'vae':
// return data as VAEModelConfig;
// default:
// return null;
// }
// }, [data]);
const {
register,
handleSubmit,
control,
formState: { errors },
reset,
watch,
} = useForm<UpdateModelArg['body']>({
defaultValues: {
...data,
@ -86,10 +48,7 @@ export const ModelEdit = () => {
mode: 'onChange',
});
const watchedModelType = watch('type');
const watchedModelFormat = watch('format');
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
(values) => {
if (!data?.key) {
return;
@ -143,33 +102,31 @@ export const ModelEdit = () => {
return (
<Flex flexDir="column" h="full">
<form onSubmit={handleSubmit(onSubmit)}>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
<Input
{...register('name', {
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
})}
size="lg"
/>
<Flex gap={2}>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
</Flex>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
</FormControl>
<Button size="sm" onClick={handleClickCancel}>
{t('common.cancel')}
</Button>
<Button
size="sm"
colorScheme="invokeYellow"
onClick={handleSubmit(onSubmit)}
isLoading={isSubmitting}
isDisabled={Boolean(Object.keys(errors).length)}
>
{t('common.save')}
</Button>
</Flex>
<Flex flexDir="column" gap={3} mt="4">
<Flex>
@ -184,76 +141,22 @@ export const ModelEdit = () => {
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={control} name="base" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.modelType')}</FormLabel>
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
<BaseModelSelect control={control} />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('common.format')}</FormLabel>
<ModelFormatSelect control={control} name="format" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
<FormLabel>{t('modelManager.path')}</FormLabel>
<Input
{...register('path', {
validate: (value) => value.trim().length > 0 || 'Must provide a path',
})}
/>
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
</FormControl>
</Flex>
{watchedModelType === 'main' && (
<>
<Flex gap={4}>
{watchedModelFormat === 'diffusers' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
</FormControl>
)}
{watchedModelFormat === 'checkpoint' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
<Input {...register('config')} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
</FormControl>
</Flex>
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
<Input {...register('vae')} />
</FormControl>
</Flex>
</>
)}
{watchedModelType === 'ip_adapter' && (
{data.type === 'main' && data.format === 'checkpoint' && (
<Flex gap={4}>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
<Input {...register('image_encoder_model_id')} />
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
<PredictionTypeSelect control={control} />
</FormControl>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
<Checkbox {...register('upcast_attention')} />
</FormControl>
</Flex>
)}

View File

@ -91,26 +91,19 @@ export const ModelView = () => {
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
</Flex>
{modelData.type === 'main' && (
<>
<Flex gap={2}>
{modelData.format === 'diffusers' && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
)}
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
</Flex>
</>
<Flex gap={2}>
{modelData.format === 'diffusers' && modelData.repo_variant && (
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
)}
{modelData.format === 'checkpoint' && (
<>
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
</>
)}
</Flex>
)}
{modelData.type === 'ip_adapter' && (
<Flex gap={2}>

View File

@ -1,7 +1,6 @@
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import type { JSONObject } from 'common/types';
import queryString from 'query-string';
import type { operations, paths } from 'services/api/schema';
import type {
@ -24,49 +23,33 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelMetadataArg = {
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type GetModelMetadataResponse =
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type DeleteMainModelArg = {
type DeleteModelArg = {
key: string;
};
type DeleteMainModelResponse = void;
type DeleteModelResponse = void;
type ConvertMainModelResponse =
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
type InstallModelArg = {
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
config?: JSONObject;
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
};
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
type ListImportModelsResponse =
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
type ListModelInstallsResponse =
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
type DeleteImportModelsResponse =
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
type CancelModelInstallResponse =
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
type PruneModelImportsResponse =
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
type PruneCompletedModelInstallsResponse =
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
export type ScanFolderResponse =
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
@ -146,31 +129,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getModelMetadata: build.query<GetModelMetadataResponse, string>({
query: (key) => {
return buildModelsUrl(`i/${key}/metadata`);
},
providesTags: ['Model'],
}),
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}`),
@ -180,28 +139,17 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
query: ({ key, body }) => {
return {
url: buildModelsUrl(`i/${key}/metadata`),
method: 'PATCH',
body: body,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source, config, access_token }) => {
query: ({ source }) => {
return {
url: buildModelsUrl('install'),
params: { source, access_token },
params: { source },
method: 'POST',
body: config,
};
},
invalidatesTags: ['Model', 'ModelImports'],
invalidatesTags: ['Model', 'ModelInstalls'],
}),
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
query: ({ key }) => {
return {
url: buildModelsUrl(`i/${key}`),
@ -210,7 +158,7 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
convertModel: build.mutation<ConvertMainModelResponse, string>({
query: (key) => {
return {
url: buildModelsUrl(`convert/${key}`),
@ -253,6 +201,57 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
};
},
providesTags: ['ModelInstalls'],
}),
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`install/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
query: () => {
return {
url: buildModelsUrl('install'),
method: 'DELETE',
};
},
invalidatesTags: ['ModelInstalls'],
}),
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
query: (base_models) => {
const params: ListModelsArg = {
model_type: 'main',
base_models,
};
const query = queryString.stringify(params, { arrayFormat: 'none' });
return buildModelsUrl(`?${query}`);
},
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
queryFulfilled.then(({ data }) => {
upsertModelConfigs(data, dispatch);
});
},
}),
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
@ -313,40 +312,6 @@ export const modelsApi = api.injectEndpoints({
});
},
}),
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
query: (arg) => {
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
return {
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
};
},
}),
getModelImports: build.query<ListImportModelsResponse, void>({
query: () => {
return {
url: buildModelsUrl(`import`),
};
},
providesTags: ['ModelImports'],
}),
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
query: (id) => {
return {
url: buildModelsUrl(`import/${id}`),
method: 'DELETE',
};
},
invalidatesTags: ['ModelImports'],
}),
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
query: () => {
return {
url: buildModelsUrl('import'),
method: 'PATCH',
};
},
invalidatesTags: ['ModelImports'],
}),
}),
});
@ -360,16 +325,14 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useUpdateModelsMutation,
useUpdateModelMutation,
useInstallModelMutation,
useConvertMainModelsMutation,
useConvertModelMutation,
useSyncModelsMutation,
useLazyScanModelsQuery,
useGetModelImportsQuery,
useGetModelMetadataQuery,
useDeleteModelImportMutation,
usePruneModelImportsMutation,
useUpdateModelMetadataMutation,
useLazyScanFolderQuery,
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,
} = modelsApi;
const upsertModelConfigs = (

View File

@ -28,7 +28,7 @@ export const tagTypes = [
'InvocationCacheStatus',
'Model',
'ModelConfig',
'ModelImports',
'ModelInstalls',
'T2IAdapterModel',
'MainModel',
'VaeModel',

File diff suppressed because one or more lines are too long

View File

@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
// Model Configs
// TODO(MM2): Can we make key required in the pydantic model?
export type LoRAModelConfig = S['LoRAConfig'];
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig'];
// TODO(MM2): Can we rename this to T2IAdapterConfig
export type T2IAdapterModelConfig = S['T2IConfig'];
export type TextualInversionModelConfig = S['TextualInversionConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
export type DiffusersModelConfig = S['MainDiffusersConfig'];
export type CheckpointModelConfig = S['MainCheckpointConfig'];
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];