mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): update UI to use new model config backend
- Update all queries - Remove Advanced Add - Removed un-editable, internal-only model attributes from model edit UI (e.g. format, repo variant, model type) - Update model tags so the list refreshes when a model installs - Rename some queries, components, variables, types to match backend - Fix divide-by-zero in install queue
This commit is contained in:
parent
48119d9010
commit
99407c899f
@ -34,13 +34,13 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const metadata = await dispatch(modelsApi.endpoints.getModelMetadata.initiate(currentModel.key)).unwrap();
|
const modelConfig = await dispatch(modelsApi.endpoints.getModelConfig.initiate(currentModel.key)).unwrap();
|
||||||
|
|
||||||
if (!metadata || !metadata.default_settings) {
|
if (!modelConfig || !modelConfig.default_settings) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = metadata.default_settings;
|
const { vae, vae_precision, cfg_scale, cfg_rescale_multiplier, steps, scheduler } = modelConfig.default_settings;
|
||||||
|
|
||||||
if (vae) {
|
if (vae) {
|
||||||
// we store this as "default" within default settings
|
// we store this as "default" within default settings
|
||||||
|
@ -14,7 +14,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
const { bytes, total_bytes, id } = action.payload.data;
|
const { bytes, total_bytes, id } = action.payload.data;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
if (modelImport) {
|
if (modelImport) {
|
||||||
modelImport.bytes = bytes;
|
modelImport.bytes = bytes;
|
||||||
@ -33,7 +33,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
const { id } = action.payload.data;
|
const { id } = action.payload.data;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
if (modelImport) {
|
if (modelImport) {
|
||||||
modelImport.status = 'completed';
|
modelImport.status = 'completed';
|
||||||
@ -41,7 +41,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
return draft;
|
return draft;
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
dispatch(api.util.invalidateTags([{ type: 'ModelConfig' }]));
|
dispatch(api.util.invalidateTags(['Model']));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -51,7 +51,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
|||||||
const { id, error, error_type } = action.payload.data;
|
const { id, error, error_type } = action.payload.data;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
modelsApi.util.updateQueryData('getModelImports', undefined, (draft) => {
|
modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
|
||||||
const modelImport = draft.find((m) => m.id === id);
|
const modelImport = draft.find((m) => m.id === id);
|
||||||
if (modelImport) {
|
if (modelImport) {
|
||||||
modelImport.status = 'error';
|
modelImport.status = 'error';
|
||||||
|
@ -1,228 +0,0 @@
|
|||||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text, Textarea } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
|
||||||
import BaseModelSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BaseModelSelect';
|
|
||||||
import BooleanSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/BooleanSelect';
|
|
||||||
import ModelFormatSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelFormatSelect';
|
|
||||||
import ModelTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelTypeSelect';
|
|
||||||
import ModelVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/ModelVariantSelect';
|
|
||||||
import PredictionTypeSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/PredictionTypeSelect';
|
|
||||||
import RepoVariantSelect from 'features/modelManagerV2/subpanels/ModelPanel/Fields/RepoVariantSelect';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { isNil, omitBy } from 'lodash-es';
|
|
||||||
import { useCallback, useEffect } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
export const AdvancedImport = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const [installModel] = useInstallModelMutation();
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
setValue,
|
|
||||||
resetField,
|
|
||||||
reset,
|
|
||||||
watch,
|
|
||||||
} = useForm<AnyModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: '',
|
|
||||||
base: 'sd-1',
|
|
||||||
type: 'main',
|
|
||||||
path: '',
|
|
||||||
description: '',
|
|
||||||
format: 'diffusers',
|
|
||||||
vae: '',
|
|
||||||
variant: 'normal',
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
installModel({
|
|
||||||
source: values.path,
|
|
||||||
config: omitBy(values, isNil),
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelAdded', {
|
|
||||||
modelName: values.name,
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reset();
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[installModel, dispatch, t, reset]
|
|
||||||
);
|
|
||||||
|
|
||||||
const watchedModelType = watch('type');
|
|
||||||
const watchedModelFormat = watch('format');
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (watchedModelType === 'main') {
|
|
||||||
setValue('format', 'diffusers');
|
|
||||||
setValue('repo_variant', '');
|
|
||||||
setValue('variant', 'normal');
|
|
||||||
}
|
|
||||||
if (watchedModelType === 'lora') {
|
|
||||||
setValue('format', 'lycoris');
|
|
||||||
} else if (watchedModelType === 'embedding') {
|
|
||||||
setValue('format', 'embedding_file');
|
|
||||||
} else if (watchedModelType === 'ip_adapter') {
|
|
||||||
setValue('format', 'invokeai');
|
|
||||||
} else {
|
|
||||||
setValue('format', 'diffusers');
|
|
||||||
}
|
|
||||||
resetField('upcast_attention');
|
|
||||||
resetField('ztsnr_training');
|
|
||||||
resetField('vae');
|
|
||||||
resetField('config');
|
|
||||||
resetField('prediction_type');
|
|
||||||
resetField('image_encoder_model_id');
|
|
||||||
}, [watchedModelType, resetField, setValue]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<ScrollableContent>
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
|
||||||
<Flex flexDirection="column" gap={4} width="100%" pb={10}>
|
|
||||||
<Flex alignItems="flex-end" gap="4">
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
|
||||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
|
||||||
</FormControl>
|
|
||||||
<Text px="2" fontSize="xs" textAlign="center">
|
|
||||||
{t('modelManager.advancedImportInfo')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex p={4} borderRadius={4} bg="base.850" height="100%" direction="column" gap="3">
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length >= 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
<Flex>
|
|
||||||
<FormControl>
|
|
||||||
<Flex direction="column" width="full">
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Textarea size="sm" {...register('description')} />
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
|
||||||
<BaseModelSelect control={control} name="base" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('common.format')}</FormLabel>
|
|
||||||
<ModelFormatSelect control={control} name="format" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
{watchedModelType === 'main' && (
|
|
||||||
<>
|
|
||||||
<Flex gap={4}>
|
|
||||||
{watchedModelFormat === 'diffusers' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
|
||||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
{watchedModelFormat === 'checkpoint' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
|
||||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
|
||||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{watchedModelType === 'ip_adapter' && (
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
|
||||||
<Input {...register('image_encoder_model_id')} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
<Button mt={2} type="submit">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
</ScrollableContent>
|
|
||||||
);
|
|
||||||
};
|
|
@ -12,7 +12,7 @@ type SimpleImportModelConfig = {
|
|||||||
location: string;
|
location: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const SimpleImport = () => {
|
export const InstallModelForm = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [installModel, { isLoading }] = useInstallModelMutation();
|
const [installModel, { isLoading }] = useInstallModelMutation();
|
@ -5,19 +5,19 @@ import { addToast } from 'features/system/store/systemSlice';
|
|||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { useGetModelImportsQuery, usePruneModelImportsMutation } from 'services/api/endpoints/models';
|
import { useListModelInstallsQuery, usePruneCompletedModelInstallsMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ImportQueueItem } from './ImportQueueItem';
|
import { ModelInstallQueueItem } from './ModelInstallQueueItem';
|
||||||
|
|
||||||
export const ImportQueue = () => {
|
export const ModelInstallQueue = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const { data } = useGetModelImportsQuery();
|
const { data } = useListModelInstallsQuery();
|
||||||
|
|
||||||
const [pruneModelImports] = usePruneModelImportsMutation();
|
const [_pruneCompletedModelInstalls] = usePruneCompletedModelInstallsMutation();
|
||||||
|
|
||||||
const pruneQueue = useCallback(() => {
|
const pruneCompletedModelInstalls = useCallback(() => {
|
||||||
pruneModelImports()
|
_pruneCompletedModelInstalls()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -41,7 +41,7 @@ export const ImportQueue = () => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, [pruneModelImports, dispatch]);
|
}, [_pruneCompletedModelInstalls, dispatch]);
|
||||||
|
|
||||||
const pruneAvailable = useMemo(() => {
|
const pruneAvailable = useMemo(() => {
|
||||||
return data?.some(
|
return data?.some(
|
||||||
@ -53,14 +53,19 @@ export const ImportQueue = () => {
|
|||||||
<Flex flexDir="column" p={3} h="full">
|
<Flex flexDir="column" p={3} h="full">
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
<Text>{t('modelManager.importQueue')}</Text>
|
<Text>{t('modelManager.importQueue')}</Text>
|
||||||
<Button size="sm" isDisabled={!pruneAvailable} onClick={pruneQueue} tooltip={t('modelManager.pruneTooltip')}>
|
<Button
|
||||||
|
size="sm"
|
||||||
|
isDisabled={!pruneAvailable}
|
||||||
|
onClick={pruneCompletedModelInstalls}
|
||||||
|
tooltip={t('modelManager.pruneTooltip')}
|
||||||
|
>
|
||||||
{t('modelManager.prune')}
|
{t('modelManager.prune')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
<Box mt={3} layerStyle="first" p={3} borderRadius="base" w="full" h="full">
|
||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex flexDir="column-reverse" gap="2">
|
<Flex flexDir="column-reverse" gap="2">
|
||||||
{data?.map((model) => <ImportQueueItem key={model.id} model={model} />)}
|
{data?.map((model) => <ModelInstallQueueItem key={model.id} installJob={model} />)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
</Box>
|
</Box>
|
@ -6,17 +6,24 @@ import type { ModelInstallStatus } from 'services/api/types';
|
|||||||
const STATUSES = {
|
const STATUSES = {
|
||||||
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
|
waiting: { colorScheme: 'cyan', translationKey: 'queue.pending' },
|
||||||
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
downloading: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||||
|
downloads_done: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||||
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
running: { colorScheme: 'yellow', translationKey: 'queue.in_progress' },
|
||||||
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
|
completed: { colorScheme: 'green', translationKey: 'queue.completed' },
|
||||||
error: { colorScheme: 'red', translationKey: 'queue.failed' },
|
error: { colorScheme: 'red', translationKey: 'queue.failed' },
|
||||||
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
cancelled: { colorScheme: 'orange', translationKey: 'queue.canceled' },
|
||||||
};
|
};
|
||||||
|
|
||||||
const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus; errorReason?: string | null }) => {
|
const ModelInstallQueueBadge = ({
|
||||||
|
status,
|
||||||
|
errorReason,
|
||||||
|
}: {
|
||||||
|
status?: ModelInstallStatus;
|
||||||
|
errorReason?: string | null;
|
||||||
|
}) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
if (!status || !Object.keys(STATUSES).includes(status)) {
|
if (!status || !Object.keys(STATUSES).includes(status)) {
|
||||||
return <></>;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -25,4 +32,4 @@ const ImportQueueBadge = ({ status, errorReason }: { status?: ModelInstallStatus
|
|||||||
</Tooltip>
|
</Tooltip>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
export default memo(ImportQueueBadge);
|
export default memo(ModelInstallQueueBadge);
|
@ -3,15 +3,16 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
|
import { isNil } from 'lodash-es';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import { useDeleteModelImportMutation } from 'services/api/endpoints/models';
|
import { useCancelModelInstallMutation } from 'services/api/endpoints/models';
|
||||||
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
|
import type { HFModelSource, LocalModelSource, ModelInstallJob, URLModelSource } from 'services/api/types';
|
||||||
|
|
||||||
import ImportQueueBadge from './ImportQueueBadge';
|
import ModelInstallQueueBadge from './ModelInstallQueueBadge';
|
||||||
|
|
||||||
type ModelListItemProps = {
|
type ModelListItemProps = {
|
||||||
model: ModelInstallJob;
|
installJob: ModelInstallJob;
|
||||||
};
|
};
|
||||||
|
|
||||||
const formatBytes = (bytes: number) => {
|
const formatBytes = (bytes: number) => {
|
||||||
@ -26,26 +27,26 @@ const formatBytes = (bytes: number) => {
|
|||||||
return `${bytes.toFixed(2)} ${units[i]}`;
|
return `${bytes.toFixed(2)} ${units[i]}`;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const ImportQueueItem = (props: ModelListItemProps) => {
|
export const ModelInstallQueueItem = (props: ModelListItemProps) => {
|
||||||
const { model } = props;
|
const { installJob } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
const [deleteImportModel] = useDeleteModelImportMutation();
|
const [deleteImportModel] = useCancelModelInstallMutation();
|
||||||
|
|
||||||
const source = useMemo(() => {
|
const source = useMemo(() => {
|
||||||
if (model.source.type === 'hf') {
|
if (installJob.source.type === 'hf') {
|
||||||
return model.source as HFModelSource;
|
return installJob.source as HFModelSource;
|
||||||
} else if (model.source.type === 'local') {
|
} else if (installJob.source.type === 'local') {
|
||||||
return model.source as LocalModelSource;
|
return installJob.source as LocalModelSource;
|
||||||
} else if (model.source.type === 'url') {
|
} else if (installJob.source.type === 'url') {
|
||||||
return model.source as URLModelSource;
|
return installJob.source as URLModelSource;
|
||||||
} else {
|
} else {
|
||||||
return model.source as LocalModelSource;
|
return installJob.source as LocalModelSource;
|
||||||
}
|
}
|
||||||
}, [model.source]);
|
}, [installJob.source]);
|
||||||
|
|
||||||
const handleDeleteModelImport = useCallback(() => {
|
const handleDeleteModelImport = useCallback(() => {
|
||||||
deleteImportModel(model.id)
|
deleteImportModel(installJob.id)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((_) => {
|
.then((_) => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -69,7 +70,7 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, [deleteImportModel, model, dispatch]);
|
}, [deleteImportModel, installJob, dispatch]);
|
||||||
|
|
||||||
const modelName = useMemo(() => {
|
const modelName = useMemo(() => {
|
||||||
switch (source.type) {
|
switch (source.type) {
|
||||||
@ -85,19 +86,23 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
}, [source]);
|
}, [source]);
|
||||||
|
|
||||||
const progressValue = useMemo(() => {
|
const progressValue = useMemo(() => {
|
||||||
if (model.bytes === undefined || model.total_bytes === undefined) {
|
if (isNil(installJob.bytes) || isNil(installJob.total_bytes)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (installJob.total_bytes === 0) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return (model.bytes / model.total_bytes) * 100;
|
return (installJob.bytes / installJob.total_bytes) * 100;
|
||||||
}, [model.bytes, model.total_bytes]);
|
}, [installJob.bytes, installJob.total_bytes]);
|
||||||
|
|
||||||
const progressString = useMemo(() => {
|
const progressString = useMemo(() => {
|
||||||
if (model.status !== 'downloading' || model.bytes === undefined || model.total_bytes === undefined) {
|
if (installJob.status !== 'downloading' || installJob.bytes === undefined || installJob.total_bytes === undefined) {
|
||||||
return '';
|
return '';
|
||||||
}
|
}
|
||||||
return `${formatBytes(model.bytes)} / ${formatBytes(model.total_bytes)}`;
|
return `${formatBytes(installJob.bytes)} / ${formatBytes(installJob.total_bytes)}`;
|
||||||
}, [model.bytes, model.total_bytes, model.status]);
|
}, [installJob.bytes, installJob.total_bytes, installJob.status]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap="2" w="full" alignItems="center">
|
<Flex gap="2" w="full" alignItems="center">
|
||||||
@ -109,19 +114,21 @@ export const ImportQueueItem = (props: ModelListItemProps) => {
|
|||||||
<Flex flexDir="column" flex={1}>
|
<Flex flexDir="column" flex={1}>
|
||||||
<Tooltip label={progressString}>
|
<Tooltip label={progressString}>
|
||||||
<Progress
|
<Progress
|
||||||
value={progressValue}
|
value={progressValue ?? 0}
|
||||||
isIndeterminate={progressValue === undefined}
|
isIndeterminate={progressValue === null}
|
||||||
aria-label={t('accessibility.invokeProgressBar')}
|
aria-label={t('accessibility.invokeProgressBar')}
|
||||||
h={2}
|
h={2}
|
||||||
/>
|
/>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Box minW="100px" textAlign="center">
|
<Box minW="100px" textAlign="center">
|
||||||
<ImportQueueBadge status={model.status} errorReason={model.error_reason} />
|
<ModelInstallQueueBadge status={installJob.status} errorReason={installJob.error_reason} />
|
||||||
</Box>
|
</Box>
|
||||||
|
|
||||||
<Box minW="20px">
|
<Box minW="20px">
|
||||||
{(model.status === 'downloading' || model.status === 'waiting' || model.status === 'running') && (
|
{(installJob.status === 'downloading' ||
|
||||||
|
installJob.status === 'waiting' ||
|
||||||
|
installJob.status === 'running') && (
|
||||||
<IconButton
|
<IconButton
|
||||||
isRound={true}
|
isRound={true}
|
||||||
size="xs"
|
size="xs"
|
@ -2,24 +2,24 @@ import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@
|
|||||||
import type { ChangeEventHandler } from 'react';
|
import type { ChangeEventHandler } from 'react';
|
||||||
import { useCallback, useState } from 'react';
|
import { useCallback, useState } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useLazyScanModelsQuery } from 'services/api/endpoints/models';
|
import { useLazyScanFolderQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ScanModelsResults } from './ScanModelsResults';
|
import { ScanModelsResults } from './ScanFolderResults';
|
||||||
|
|
||||||
export const ScanModelsForm = () => {
|
export const ScanModelsForm = () => {
|
||||||
const [scanPath, setScanPath] = useState('');
|
const [scanPath, setScanPath] = useState('');
|
||||||
const [errorMessage, setErrorMessage] = useState('');
|
const [errorMessage, setErrorMessage] = useState('');
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [_scanModels, { isLoading, data }] = useLazyScanModelsQuery();
|
const [_scanFolder, { isLoading, data }] = useLazyScanFolderQuery();
|
||||||
|
|
||||||
const handleSubmitScan = useCallback(async () => {
|
const scanFolder = useCallback(async () => {
|
||||||
_scanModels({ scan_path: scanPath }).catch((error) => {
|
_scanFolder({ scan_path: scanPath }).catch((error) => {
|
||||||
if (error) {
|
if (error) {
|
||||||
setErrorMessage(error.data.detail);
|
setErrorMessage(error.data.detail);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}, [_scanModels, scanPath]);
|
}, [_scanFolder, scanPath]);
|
||||||
|
|
||||||
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
const handleSetScanPath: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||||
setScanPath(e.target.value);
|
setScanPath(e.target.value);
|
||||||
@ -36,7 +36,7 @@ export const ScanModelsForm = () => {
|
|||||||
<Input value={scanPath} onChange={handleSetScanPath} />
|
<Input value={scanPath} onChange={handleSetScanPath} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Button onClick={handleSubmitScan} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={scanPath.length === 0}>
|
||||||
{t('modelManager.scanFolder')}
|
{t('modelManager.scanFolder')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
@ -18,7 +18,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { PiXBold } from 'react-icons/pi';
|
import { PiXBold } from 'react-icons/pi';
|
||||||
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
|
import { type ScanFolderResponse, useInstallModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { ScanModelResultItem } from './ScanModelResultItem';
|
import { ScanModelResultItem } from './ScanFolderResultItem';
|
||||||
|
|
||||||
type ScanModelResultsProps = {
|
type ScanModelResultsProps = {
|
||||||
results: ScanFolderResponse;
|
results: ScanFolderResponse;
|
@ -1,12 +1,11 @@
|
|||||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { AdvancedImport } from './AddModelPanel/AdvancedImport';
|
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||||
import { ImportQueue } from './AddModelPanel/ImportQueue/ImportQueue';
|
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||||
import { ScanModelsForm } from './AddModelPanel/ScanModels/ScanModelsForm';
|
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||||
import { SimpleImport } from './AddModelPanel/SimpleImport';
|
|
||||||
|
|
||||||
export const ImportModels = () => {
|
export const InstallModels = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
return (
|
return (
|
||||||
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
<Flex layerStyle="first" p={3} borderRadius="base" w="full" h="full" flexDir="column" gap={2}>
|
||||||
@ -17,15 +16,11 @@ export const ImportModels = () => {
|
|||||||
<Tabs variant="collapse" height="100%">
|
<Tabs variant="collapse" height="100%">
|
||||||
<TabList>
|
<TabList>
|
||||||
<Tab>{t('common.simple')}</Tab>
|
<Tab>{t('common.simple')}</Tab>
|
||||||
<Tab>{t('modelManager.advanced')}</Tab>
|
|
||||||
<Tab>{t('modelManager.scan')}</Tab>
|
<Tab>{t('modelManager.scan')}</Tab>
|
||||||
</TabList>
|
</TabList>
|
||||||
<TabPanels p={3} height="100%">
|
<TabPanels p={3} height="100%">
|
||||||
<TabPanel>
|
<TabPanel>
|
||||||
<SimpleImport />
|
<InstallModelForm />
|
||||||
</TabPanel>
|
|
||||||
<TabPanel height="100%">
|
|
||||||
<AdvancedImport />
|
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
<TabPanel height="100%">
|
<TabPanel height="100%">
|
||||||
<ScanModelsForm />
|
<ScanModelsForm />
|
||||||
@ -34,7 +29,7 @@ export const ImportModels = () => {
|
|||||||
</Tabs>
|
</Tabs>
|
||||||
</Box>
|
</Box>
|
||||||
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
|
<Box layerStyle="second" borderRadius="base" w="full" h="50%">
|
||||||
<ImportQueue />
|
<ModelInstallQueue />
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
@ -5,7 +5,7 @@ import { useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoFilter } from 'react-icons/io5';
|
import { IoFilter } from 'react-icons/io5';
|
||||||
|
|
||||||
export const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
const MODEL_TYPE_LABELS: { [key: string]: string } = {
|
||||||
main: 'Main',
|
main: 'Main',
|
||||||
lora: 'LoRA',
|
lora: 'LoRA',
|
||||||
embedding: 'Textual Inversion',
|
embedding: 'Textual Inversion',
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import { Box } from '@invoke-ai/ui-library';
|
import { Box } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
|
||||||
import { ImportModels } from './ImportModels';
|
import { InstallModels } from './InstallModels';
|
||||||
import { Model } from './ModelPanel/Model';
|
import { Model } from './ModelPanel/Model';
|
||||||
|
|
||||||
export const ModelPane = () => {
|
export const ModelPane = () => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
<Box layerStyle="first" p={2} borderRadius="base" w="50%" h="full">
|
||||||
{selectedModelKey ? <Model key={selectedModelKey} /> : <ImportModels />}
|
{selectedModelKey ? <Model key={selectedModelKey} /> : <InstallModels />}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -5,7 +5,7 @@ import Loading from 'common/components/Loading/Loading';
|
|||||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||||
import { isNil } from 'lodash-es';
|
import { isNil } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
import { DefaultSettingsForm } from './DefaultSettings/DefaultSettingsForm';
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ const initialStatesSelector = createMemoizedSelector(selectConfigSlice, (config)
|
|||||||
export const DefaultSettings = () => {
|
export const DefaultSettings = () => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
|
|
||||||
const { data, isLoading } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
const { initialSteps, initialCfg, initialScheduler, initialCfgRescaleMultiplier, initialVaePrecision } =
|
||||||
useAppSelector(initialStatesSelector);
|
useAppSelector(initialStatesSelector);
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import type { SubmitHandler } from 'react-hook-form';
|
|||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { IoPencil } from 'react-icons/io5';
|
import { IoPencil } from 'react-icons/io5';
|
||||||
import { useUpdateModelMetadataMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
import { DefaultCfgRescaleMultiplier } from './DefaultCfgRescaleMultiplier';
|
||||||
import { DefaultCfgScale } from './DefaultCfgScale';
|
import { DefaultCfgScale } from './DefaultCfgScale';
|
||||||
@ -41,7 +41,7 @@ export const DefaultSettingsForm = ({
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
|
|
||||||
const [editModelMetadata, { isLoading }] = useUpdateModelMetadataMutation();
|
const [updateModel, { isLoading }] = useUpdateModelMutation();
|
||||||
|
|
||||||
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
const { handleSubmit, control, formState } = useForm<DefaultSettingsFormData>({
|
||||||
defaultValues: defaultSettingsDefaults,
|
defaultValues: defaultSettingsDefaults,
|
||||||
@ -62,7 +62,7 @@ export const DefaultSettingsForm = ({
|
|||||||
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
scheduler: data.scheduler.isEnabled ? data.scheduler.value : null,
|
||||||
};
|
};
|
||||||
|
|
||||||
editModelMetadata({
|
updateModel({
|
||||||
key: selectedModelKey,
|
key: selectedModelKey,
|
||||||
body: { default_settings: body },
|
body: { default_settings: body },
|
||||||
})
|
})
|
||||||
@ -90,7 +90,7 @@ export const DefaultSettingsForm = ({
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
[selectedModelKey, dispatch, editModelMetadata, t]
|
[selectedModelKey, dispatch, updateModel, t]
|
||||||
);
|
);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -3,9 +3,9 @@ import { Combobox } from '@invoke-ai/ui-library';
|
|||||||
import { typedMemo } from 'common/util/typedMemo';
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { Control } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
const options: ComboboxOption[] = [
|
||||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
@ -14,8 +14,12 @@ const options: ComboboxOption[] = [
|
|||||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||||
];
|
];
|
||||||
|
|
||||||
const BaseModelSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
type Props = {
|
||||||
const { field } = useController(props);
|
control: Control<UpdateModelArg['body']>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const BaseModelSelect = ({ control }: Props) => {
|
||||||
|
const { field } = useController({ control, name: 'base' });
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController } from 'react-hook-form';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ value: 'none', label: '-' },
|
|
||||||
{ value: 'true', label: 'True' },
|
|
||||||
{ value: 'false', label: 'False' },
|
|
||||||
];
|
|
||||||
|
|
||||||
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value === 'true');
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(BooleanSelect);
|
|
@ -1,47 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController, useWatch } from 'react-hook-form';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const ModelFormatSelect = (props: UseControllerProps<AnyModelConfig>) => {
|
|
||||||
const { field, formState } = useController(props);
|
|
||||||
const type = useWatch({ control: props.control, name: 'type' });
|
|
||||||
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = useMemo(() => {
|
|
||||||
const modelType = type || formState.defaultValues?.type;
|
|
||||||
if (modelType === 'lora') {
|
|
||||||
return [
|
|
||||||
{ value: 'lycoris', label: 'LyCORIS' },
|
|
||||||
{ value: 'diffusers', label: 'Diffusers' },
|
|
||||||
];
|
|
||||||
} else if (modelType === 'embedding') {
|
|
||||||
return [
|
|
||||||
{ value: 'embedding_file', label: 'Embedding File' },
|
|
||||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
|
||||||
];
|
|
||||||
} else if (modelType === 'ip_adapter') {
|
|
||||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
|
||||||
} else {
|
|
||||||
return [
|
|
||||||
{ value: 'diffusers', label: 'Diffusers' },
|
|
||||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
|
||||||
];
|
|
||||||
}
|
|
||||||
}, [type, formState.defaultValues?.type]);
|
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
|
||||||
|
|
||||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(ModelFormatSelect);
|
|
@ -1,32 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { MODEL_TYPE_LABELS } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController } from 'react-hook-form';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
|
||||||
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
|
||||||
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
|
||||||
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
|
||||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
|
||||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
|
||||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
|
||||||
] as const;
|
|
||||||
|
|
||||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(ModelTypeSelect);
|
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import { Combobox } from '@invoke-ai/ui-library';
|
import { Combobox } from '@invoke-ai/ui-library';
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { Control } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
const options: ComboboxOption[] = [
|
||||||
{ value: 'normal', label: 'Normal' },
|
{ value: 'normal', label: 'Normal' },
|
||||||
@ -12,8 +12,12 @@ const options: ComboboxOption[] = [
|
|||||||
{ value: 'depth', label: 'Depth' },
|
{ value: 'depth', label: 'Depth' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
type Props = {
|
||||||
const { field } = useController(props);
|
control: Control<UpdateModelArg['body']>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ModelVariantSelect = ({ control }: Props) => {
|
||||||
|
const { field } = useController({ control, name: 'variant' });
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
|
@ -2,9 +2,9 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import { Combobox } from '@invoke-ai/ui-library';
|
import { Combobox } from '@invoke-ai/ui-library';
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { Control } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController } from 'react-hook-form';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
const options: ComboboxOption[] = [
|
||||||
{ value: 'none', label: '-' },
|
{ value: 'none', label: '-' },
|
||||||
@ -13,8 +13,12 @@ const options: ComboboxOption[] = [
|
|||||||
{ value: 'sample', label: 'sample' },
|
{ value: 'sample', label: 'sample' },
|
||||||
];
|
];
|
||||||
|
|
||||||
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
type Props = {
|
||||||
const { field } = useController(props);
|
control: Control<UpdateModelArg['body']>;
|
||||||
|
};
|
||||||
|
|
||||||
|
const PredictionTypeSelect = ({ control }: Props) => {
|
||||||
|
const { field } = useController({ control, name: 'prediction_type' });
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
(v) => {
|
(v) => {
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController } from 'react-hook-form';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ value: 'none', label: '-' },
|
|
||||||
{ value: 'fp16', label: 'fp16' },
|
|
||||||
{ value: 'fp32', label: 'fp32' },
|
|
||||||
];
|
|
||||||
|
|
||||||
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(RepoVariantSelect);
|
|
@ -2,16 +2,16 @@ import { Flex } from '@invoke-ai/ui-library';
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
import { useGetModelMetadataQuery } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
export const ModelMetadata = () => {
|
export const ModelMetadata = () => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
const { data } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Flex flexDir="column" height="full" gap="3">
|
<Flex flexDir="column" height="full" gap="3">
|
||||||
<DataViewer label="metadata" data={metadata || {}} />
|
<DataViewer label="metadata" data={data?.source_api_response || {}} />
|
||||||
</Flex>
|
</Flex>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
@ -13,7 +13,7 @@ import { addToast } from 'features/system/store/systemSlice';
|
|||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
import { useConvertModelMutation } from 'services/api/endpoints/models';
|
||||||
import type { CheckpointModelConfig } from 'services/api/types';
|
import type { CheckpointModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
interface ModelConvertProps {
|
interface ModelConvertProps {
|
||||||
@ -24,7 +24,7 @@ export const ModelConvert = (props: ModelConvertProps) => {
|
|||||||
const { model } = props;
|
const { model } = props;
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
const [convertModel, { isLoading }] = useConvertModelMutation();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
|
|
||||||
const modelConvertHandler = useCallback(() => {
|
const modelConvertHandler = useCallback(() => {
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import {
|
import {
|
||||||
Button,
|
Button,
|
||||||
|
Checkbox,
|
||||||
Flex,
|
Flex,
|
||||||
FormControl,
|
FormControl,
|
||||||
FormErrorMessage,
|
FormErrorMessage,
|
||||||
@ -19,66 +20,27 @@ import type { SubmitHandler } from 'react-hook-form';
|
|||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
import { useGetModelConfigQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import BooleanSelect from './Fields/BooleanSelect';
|
|
||||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
|
||||||
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
|
||||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
|
||||||
|
|
||||||
export const ModelEdit = () => {
|
export const ModelEdit = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
// const modelData = useMemo(() => {
|
|
||||||
// if (!data) {
|
|
||||||
// return null;
|
|
||||||
// }
|
|
||||||
// const modelFormat = data.format;
|
|
||||||
// const modelType = data.type;
|
|
||||||
|
|
||||||
// if (modelType === 'main') {
|
|
||||||
// if (modelFormat === 'diffusers') {
|
|
||||||
// return data as DiffusersModelConfig;
|
|
||||||
// } else if (modelFormat === 'checkpoint') {
|
|
||||||
// return data as CheckpointModelConfig;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// switch (modelType) {
|
|
||||||
// case 'lora':
|
|
||||||
// return data as LoRAModelConfig;
|
|
||||||
// case 'embedding':
|
|
||||||
// return data as TextualInversionModelConfig;
|
|
||||||
// case 't2i_adapter':
|
|
||||||
// return data as T2IAdapterModelConfig;
|
|
||||||
// case 'ip_adapter':
|
|
||||||
// return data as IPAdapterModelConfig;
|
|
||||||
// case 'controlnet':
|
|
||||||
// return data as ControlNetModelConfig;
|
|
||||||
// case 'vae':
|
|
||||||
// return data as VAEModelConfig;
|
|
||||||
// default:
|
|
||||||
// return null;
|
|
||||||
// }
|
|
||||||
// }, [data]);
|
|
||||||
|
|
||||||
const {
|
const {
|
||||||
register,
|
register,
|
||||||
handleSubmit,
|
handleSubmit,
|
||||||
control,
|
control,
|
||||||
formState: { errors },
|
formState: { errors },
|
||||||
reset,
|
reset,
|
||||||
watch,
|
|
||||||
} = useForm<UpdateModelArg['body']>({
|
} = useForm<UpdateModelArg['body']>({
|
||||||
defaultValues: {
|
defaultValues: {
|
||||||
...data,
|
...data,
|
||||||
@ -86,10 +48,7 @@ export const ModelEdit = () => {
|
|||||||
mode: 'onChange',
|
mode: 'onChange',
|
||||||
});
|
});
|
||||||
|
|
||||||
const watchedModelType = watch('type');
|
const onSubmit = useCallback<SubmitHandler<UpdateModelArg['body']>>(
|
||||||
const watchedModelFormat = watch('format');
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<AnyModelConfig>>(
|
|
||||||
(values) => {
|
(values) => {
|
||||||
if (!data?.key) {
|
if (!data?.key) {
|
||||||
return;
|
return;
|
||||||
@ -143,33 +102,31 @@ export const ModelEdit = () => {
|
|||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" h="full">
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
<form onSubmit={handleSubmit(onSubmit)}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.name)}>
|
||||||
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
|
<FormLabel hidden={true}>{t('modelManager.modelName')}</FormLabel>
|
||||||
<Input
|
<Input
|
||||||
{...register('name', {
|
{...register('name', {
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
validate: (value) => (value && value.trim().length > 3) || 'Must be at least 3 characters',
|
||||||
})}
|
})}
|
||||||
size="lg"
|
size="lg"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<Flex gap={2}>
|
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||||
<Button size="sm" onClick={handleClickCancel}>
|
</FormControl>
|
||||||
{t('common.cancel')}
|
<Button size="sm" onClick={handleClickCancel}>
|
||||||
</Button>
|
{t('common.cancel')}
|
||||||
<Button
|
</Button>
|
||||||
size="sm"
|
<Button
|
||||||
colorScheme="invokeYellow"
|
size="sm"
|
||||||
onClick={handleSubmit(onSubmit)}
|
colorScheme="invokeYellow"
|
||||||
isLoading={isSubmitting}
|
onClick={handleSubmit(onSubmit)}
|
||||||
isDisabled={Boolean(Object.keys(errors).length)}
|
isLoading={isSubmitting}
|
||||||
>
|
isDisabled={Boolean(Object.keys(errors).length)}
|
||||||
{t('common.save')}
|
>
|
||||||
</Button>
|
{t('common.save')}
|
||||||
</Flex>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Flex flexDir="column" gap={3} mt="4">
|
<Flex flexDir="column" gap={3} mt="4">
|
||||||
<Flex>
|
<Flex>
|
||||||
@ -184,76 +141,22 @@ export const ModelEdit = () => {
|
|||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||||
<BaseModelSelect control={control} name="base" />
|
<BaseModelSelect control={control} />
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
|
||||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex gap={4}>
|
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('common.format')}</FormLabel>
|
|
||||||
<ModelFormatSelect control={control} name="format" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1} isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.path')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
{watchedModelType === 'main' && (
|
|
||||||
<>
|
|
||||||
<Flex gap={4}>
|
|
||||||
{watchedModelFormat === 'diffusers' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.repoVariant')}</FormLabel>
|
|
||||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
{watchedModelFormat === 'checkpoint' && (
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
|
||||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
|
||||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.ztsnrTraining')}</FormLabel>
|
|
||||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{watchedModelType === 'ip_adapter' && (
|
|
||||||
<Flex gap={4}>
|
<Flex gap={4}>
|
||||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
<FormLabel>{t('modelManager.imageEncoderModelId')}</FormLabel>
|
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||||
<Input {...register('image_encoder_model_id')} />
|
<ModelVariantSelect control={control} />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||||
|
<PredictionTypeSelect control={control} />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||||
|
<Checkbox {...register('upcast_attention')} />
|
||||||
</FormControl>
|
</FormControl>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
|
@ -91,26 +91,19 @@ export const ModelView = () => {
|
|||||||
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
<ModelAttrView label={t('modelManager.path')} value={modelData.path} />
|
||||||
</Flex>
|
</Flex>
|
||||||
{modelData.type === 'main' && (
|
{modelData.type === 'main' && (
|
||||||
<>
|
<Flex gap={2}>
|
||||||
<Flex gap={2}>
|
{modelData.format === 'diffusers' && modelData.repo_variant && (
|
||||||
{modelData.format === 'diffusers' && (
|
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
||||||
<ModelAttrView label={t('modelManager.repoVariant')} value={modelData.repo_variant} />
|
)}
|
||||||
)}
|
{modelData.format === 'checkpoint' && (
|
||||||
{modelData.format === 'checkpoint' && (
|
<>
|
||||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config} />
|
<ModelAttrView label={t('modelManager.pathToConfig')} value={modelData.config_path} />
|
||||||
)}
|
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
||||||
|
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
||||||
<ModelAttrView label={t('modelManager.variant')} value={modelData.variant} />
|
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
||||||
</Flex>
|
</>
|
||||||
<Flex gap={2}>
|
)}
|
||||||
<ModelAttrView label={t('modelManager.predictionType')} value={modelData.prediction_type} />
|
</Flex>
|
||||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${modelData.upcast_attention}`} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label={t('modelManager.ztsnrTraining')} value={`${modelData.ztsnr_training}`} />
|
|
||||||
<ModelAttrView label={t('modelManager.vae')} value={modelData.vae} />
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
{modelData.type === 'ip_adapter' && (
|
{modelData.type === 'ip_adapter' && (
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
import type { EntityAdapter, EntityState, ThunkDispatch, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||||
import type { JSONObject } from 'common/types';
|
|
||||||
import queryString from 'query-string';
|
import queryString from 'query-string';
|
||||||
import type { operations, paths } from 'services/api/schema';
|
import type { operations, paths } from 'services/api/schema';
|
||||||
import type {
|
import type {
|
||||||
@ -24,49 +23,33 @@ export type UpdateModelArg = {
|
|||||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||||
};
|
};
|
||||||
|
|
||||||
type UpdateModelMetadataArg = {
|
|
||||||
key: paths['/api/v2/models/i/{key}/metadata']['patch']['parameters']['path']['key'];
|
|
||||||
body: paths['/api/v2/models/i/{key}/metadata']['patch']['requestBody']['content']['application/json'];
|
|
||||||
};
|
|
||||||
|
|
||||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||||
type UpdateModelMetadataResponse =
|
|
||||||
paths['/api/v2/models/i/{key}/metadata']['patch']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type GetModelMetadataResponse =
|
|
||||||
paths['/api/v2/models/i/{key}/metadata']['get']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||||
|
|
||||||
type DeleteMainModelArg = {
|
type DeleteModelArg = {
|
||||||
key: string;
|
key: string;
|
||||||
};
|
};
|
||||||
|
type DeleteModelResponse = void;
|
||||||
type DeleteMainModelResponse = void;
|
|
||||||
|
|
||||||
type ConvertMainModelResponse =
|
type ConvertMainModelResponse =
|
||||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type InstallModelArg = {
|
type InstallModelArg = {
|
||||||
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
source: paths['/api/v2/models/install']['post']['parameters']['query']['source'];
|
||||||
access_token?: paths['/api/v2/models/install']['post']['parameters']['query']['access_token'];
|
|
||||||
// TODO(MM2): This is typed as `Optional[Dict[str, Any]]` in backend...
|
|
||||||
config?: JSONObject;
|
|
||||||
// config: NonNullable<paths['/api/v2/models/install']['post']['requestBody']>['content']['application/json'];
|
|
||||||
};
|
};
|
||||||
|
|
||||||
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
type InstallModelResponse = paths['/api/v2/models/install']['post']['responses']['201']['content']['application/json'];
|
||||||
|
|
||||||
type ListImportModelsResponse =
|
type ListModelInstallsResponse =
|
||||||
paths['/api/v2/models/import']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/install']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type DeleteImportModelsResponse =
|
type CancelModelInstallResponse =
|
||||||
paths['/api/v2/models/import/{id}']['delete']['responses']['201']['content']['application/json'];
|
paths['/api/v2/models/install/{id}']['delete']['responses']['201']['content']['application/json'];
|
||||||
|
|
||||||
type PruneModelImportsResponse =
|
type PruneCompletedModelInstallsResponse =
|
||||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/install']['delete']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
export type ScanFolderResponse =
|
export type ScanFolderResponse =
|
||||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||||
@ -146,31 +129,7 @@ const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
|||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
updateModel: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||||
query: (base_models) => {
|
|
||||||
const params: ListModelsArg = {
|
|
||||||
model_type: 'main',
|
|
||||||
base_models,
|
|
||||||
};
|
|
||||||
|
|
||||||
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
|
||||||
return buildModelsUrl(`?${query}`);
|
|
||||||
},
|
|
||||||
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
|
||||||
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
|
||||||
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
|
||||||
queryFulfilled.then(({ data }) => {
|
|
||||||
upsertModelConfigs(data, dispatch);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
|
||||||
query: (key) => {
|
|
||||||
return buildModelsUrl(`i/${key}/metadata`);
|
|
||||||
},
|
|
||||||
providesTags: ['Model'],
|
|
||||||
}),
|
|
||||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
|
||||||
query: ({ key, body }) => {
|
query: ({ key, body }) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl(`i/${key}`),
|
url: buildModelsUrl(`i/${key}`),
|
||||||
@ -180,28 +139,17 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
updateModelMetadata: build.mutation<UpdateModelMetadataResponse, UpdateModelMetadataArg>({
|
|
||||||
query: ({ key, body }) => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`i/${key}/metadata`),
|
|
||||||
method: 'PATCH',
|
|
||||||
body: body,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['Model'],
|
|
||||||
}),
|
|
||||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||||
query: ({ source, config, access_token }) => {
|
query: ({ source }) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl('install'),
|
url: buildModelsUrl('install'),
|
||||||
params: { source, access_token },
|
params: { source },
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: config,
|
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
invalidatesTags: ['Model', 'ModelInstalls'],
|
||||||
}),
|
}),
|
||||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
deleteModels: build.mutation<DeleteModelResponse, DeleteModelArg>({
|
||||||
query: ({ key }) => {
|
query: ({ key }) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl(`i/${key}`),
|
url: buildModelsUrl(`i/${key}`),
|
||||||
@ -210,7 +158,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
|
convertModel: build.mutation<ConvertMainModelResponse, string>({
|
||||||
query: (key) => {
|
query: (key) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl(`convert/${key}`),
|
url: buildModelsUrl(`convert/${key}`),
|
||||||
@ -253,6 +201,57 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
|
scanFolder: build.query<ScanFolderResponse, ScanFolderArg>({
|
||||||
|
query: (arg) => {
|
||||||
|
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
||||||
|
query: () => {
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl('install'),
|
||||||
|
};
|
||||||
|
},
|
||||||
|
providesTags: ['ModelInstalls'],
|
||||||
|
}),
|
||||||
|
cancelModelInstall: build.mutation<CancelModelInstallResponse, number>({
|
||||||
|
query: (id) => {
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl(`install/${id}`),
|
||||||
|
method: 'DELETE',
|
||||||
|
};
|
||||||
|
},
|
||||||
|
invalidatesTags: ['ModelInstalls'],
|
||||||
|
}),
|
||||||
|
pruneCompletedModelInstalls: build.mutation<PruneCompletedModelInstallsResponse, void>({
|
||||||
|
query: () => {
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl('install'),
|
||||||
|
method: 'DELETE',
|
||||||
|
};
|
||||||
|
},
|
||||||
|
invalidatesTags: ['ModelInstalls'],
|
||||||
|
}),
|
||||||
|
getMainModels: build.query<EntityState<MainModelConfig, string>, BaseModelType[]>({
|
||||||
|
query: (base_models) => {
|
||||||
|
const params: ListModelsArg = {
|
||||||
|
model_type: 'main',
|
||||||
|
base_models,
|
||||||
|
};
|
||||||
|
const query = queryString.stringify(params, { arrayFormat: 'none' });
|
||||||
|
return buildModelsUrl(`?${query}`);
|
||||||
|
},
|
||||||
|
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
|
||||||
|
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
|
||||||
|
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
|
||||||
|
queryFulfilled.then(({ data }) => {
|
||||||
|
upsertModelConfigs(data, dispatch);
|
||||||
|
});
|
||||||
|
},
|
||||||
|
}),
|
||||||
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
getLoRAModels: build.query<EntityState<LoRAModelConfig, string>, void>({
|
||||||
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
query: () => ({ url: buildModelsUrl(), params: { model_type: 'lora' } }),
|
||||||
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
providesTags: buildProvidesTags<LoRAModelConfig>('LoRAModel'),
|
||||||
@ -313,40 +312,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
scanModels: build.query<ScanFolderResponse, ScanFolderArg>({
|
|
||||||
query: (arg) => {
|
|
||||||
const folderQueryStr = arg ? queryString.stringify(arg, {}) : '';
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`scan_folder?${folderQueryStr}`),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
getModelImports: build.query<ListImportModelsResponse, void>({
|
|
||||||
query: () => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`import`),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
providesTags: ['ModelImports'],
|
|
||||||
}),
|
|
||||||
deleteModelImport: build.mutation<DeleteImportModelsResponse, number>({
|
|
||||||
query: (id) => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`import/${id}`),
|
|
||||||
method: 'DELETE',
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['ModelImports'],
|
|
||||||
}),
|
|
||||||
pruneModelImports: build.mutation<PruneModelImportsResponse, void>({
|
|
||||||
query: () => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl('import'),
|
|
||||||
method: 'PATCH',
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['ModelImports'],
|
|
||||||
}),
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -360,16 +325,14 @@ export const {
|
|||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useUpdateModelsMutation,
|
useUpdateModelMutation,
|
||||||
useInstallModelMutation,
|
useInstallModelMutation,
|
||||||
useConvertMainModelsMutation,
|
useConvertModelMutation,
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
useLazyScanModelsQuery,
|
useLazyScanFolderQuery,
|
||||||
useGetModelImportsQuery,
|
useListModelInstallsQuery,
|
||||||
useGetModelMetadataQuery,
|
useCancelModelInstallMutation,
|
||||||
useDeleteModelImportMutation,
|
usePruneCompletedModelInstallsMutation,
|
||||||
usePruneModelImportsMutation,
|
|
||||||
useUpdateModelMetadataMutation,
|
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
|
||||||
const upsertModelConfigs = (
|
const upsertModelConfigs = (
|
||||||
|
@ -28,7 +28,7 @@ export const tagTypes = [
|
|||||||
'InvocationCacheStatus',
|
'InvocationCacheStatus',
|
||||||
'Model',
|
'Model',
|
||||||
'ModelConfig',
|
'ModelConfig',
|
||||||
'ModelImports',
|
'ModelInstalls',
|
||||||
'T2IAdapterModel',
|
'T2IAdapterModel',
|
||||||
'MainModel',
|
'MainModel',
|
||||||
'VaeModel',
|
'VaeModel',
|
||||||
|
File diff suppressed because one or more lines are too long
@ -43,14 +43,13 @@ export type ControlField = S['ControlField'];
|
|||||||
// Model Configs
|
// Model Configs
|
||||||
|
|
||||||
// TODO(MM2): Can we make key required in the pydantic model?
|
// TODO(MM2): Can we make key required in the pydantic model?
|
||||||
export type LoRAModelConfig = S['LoRAConfig'];
|
export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
|
||||||
// TODO(MM2): Can we rename this from Vae -> VAE
|
// TODO(MM2): Can we rename this from Vae -> VAE
|
||||||
export type VAEModelConfig = S['VaeCheckpointConfig'] | S['VaeDiffusersConfig'];
|
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||||
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
export type IPAdapterModelConfig = S['IPAdapterConfig'];
|
||||||
// TODO(MM2): Can we rename this to T2IAdapterConfig
|
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||||
export type T2IAdapterModelConfig = S['T2IConfig'];
|
export type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||||
export type TextualInversionModelConfig = S['TextualInversionConfig'];
|
|
||||||
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
export type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||||
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
export type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||||
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig'];
|
||||||
|
Loading…
Reference in New Issue
Block a user