mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up old model manager components and endpoints
This commit is contained in:
parent
9b1f63379a
commit
baf1194cae
@ -15,8 +15,7 @@ import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynam
|
|||||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||||
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
|
||||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||||
@ -55,7 +54,6 @@ const allReducers = {
|
|||||||
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
|
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
|
||||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||||
[loraSlice.name]: loraSlice.reducer,
|
[loraSlice.name]: loraSlice.reducer,
|
||||||
[modelManagerSlice.name]: modelManagerSlice.reducer,
|
|
||||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||||
[sdxlSlice.name]: sdxlSlice.reducer,
|
[sdxlSlice.name]: sdxlSlice.reducer,
|
||||||
[queueSlice.name]: queueSlice.reducer,
|
[queueSlice.name]: queueSlice.reducer,
|
||||||
@ -103,7 +101,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
|||||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||||
[sdxlPersistConfig.name]: sdxlPersistConfig,
|
[sdxlPersistConfig.name]: sdxlPersistConfig,
|
||||||
[loraPersistConfig.name]: loraPersistConfig,
|
[loraPersistConfig.name]: loraPersistConfig,
|
||||||
[modelManagerPersistConfig.name]: modelManagerPersistConfig,
|
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
|
||||||
import { Button } from '@invoke-ai/ui-library';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiArrowsClockwiseBold } from 'react-icons/pi';
|
|
||||||
|
|
||||||
import { useSyncModels } from './useSyncModels';
|
|
||||||
|
|
||||||
export const SyncModelsButton = memo((props: Omit<ButtonProps, 'children'>) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { syncModels, isLoading } = useSyncModels();
|
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
|
||||||
|
|
||||||
if (!isSyncModelEnabled) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Button
|
|
||||||
isLoading={isLoading}
|
|
||||||
onClick={syncModels}
|
|
||||||
leftIcon={<PiArrowsClockwiseBold />}
|
|
||||||
minW="max-content"
|
|
||||||
{...props}
|
|
||||||
>
|
|
||||||
{t('modelManager.syncModels')}
|
|
||||||
</Button>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
SyncModelsButton.displayName = 'SyncModelsButton';
|
|
@ -1,33 +0,0 @@
|
|||||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
|
||||||
import { IconButton } from '@invoke-ai/ui-library';
|
|
||||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiArrowsClockwiseBold } from 'react-icons/pi';
|
|
||||||
|
|
||||||
import { useSyncModels } from './useSyncModels';
|
|
||||||
|
|
||||||
export const SyncModelsIconButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { syncModels, isLoading } = useSyncModels();
|
|
||||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
|
||||||
|
|
||||||
if (!isSyncModelEnabled) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<IconButton
|
|
||||||
icon={<PiArrowsClockwiseBold />}
|
|
||||||
tooltip={t('modelManager.syncModels')}
|
|
||||||
aria-label={t('modelManager.syncModels')}
|
|
||||||
isLoading={isLoading}
|
|
||||||
onClick={syncModels}
|
|
||||||
size="sm"
|
|
||||||
variant="ghost"
|
|
||||||
{...props}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
SyncModelsIconButton.displayName = 'SyncModelsIconButton';
|
|
@ -1,40 +0,0 @@
|
|||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useCallback } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useSyncModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
export const useSyncModels = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const [_syncModels, { isLoading }] = useSyncModelsMutation();
|
|
||||||
const syncModels = useCallback(() => {
|
|
||||||
_syncModels()
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelsSynced')}`,
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelSyncFailed')}`,
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}, [dispatch, _syncModels, t]);
|
|
||||||
|
|
||||||
return { syncModels, isLoading };
|
|
||||||
};
|
|
@ -1,47 +0,0 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
|
||||||
|
|
||||||
type ModelManagerState = {
|
|
||||||
_version: 1;
|
|
||||||
searchFolder: string | null;
|
|
||||||
advancedAddScanModel: string | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const initialModelManagerState: ModelManagerState = {
|
|
||||||
_version: 1,
|
|
||||||
searchFolder: null,
|
|
||||||
advancedAddScanModel: null,
|
|
||||||
};
|
|
||||||
|
|
||||||
export const modelManagerSlice = createSlice({
|
|
||||||
name: 'modelmanager',
|
|
||||||
initialState: initialModelManagerState,
|
|
||||||
reducers: {
|
|
||||||
setSearchFolder: (state, action: PayloadAction<string | null>) => {
|
|
||||||
state.searchFolder = action.payload;
|
|
||||||
},
|
|
||||||
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
|
|
||||||
state.advancedAddScanModel = action.payload;
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export const { setSearchFolder, setAdvancedAddScanModel } = modelManagerSlice.actions;
|
|
||||||
|
|
||||||
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
|
||||||
|
|
||||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
|
||||||
export const migrateModelManagerState = (state: any): any => {
|
|
||||||
if (!('_version' in state)) {
|
|
||||||
state._version = 1;
|
|
||||||
}
|
|
||||||
return state;
|
|
||||||
};
|
|
||||||
|
|
||||||
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
|
|
||||||
name: modelManagerSlice.name,
|
|
||||||
initialState: initialModelManagerState,
|
|
||||||
migrate: migrateModelManagerState,
|
|
||||||
persistDenylist: [],
|
|
||||||
};
|
|
@ -1,35 +0,0 @@
|
|||||||
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { memo, useCallback, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
import AdvancedAddModels from './AdvancedAddModels';
|
|
||||||
import SimpleAddModels from './SimpleAddModels';
|
|
||||||
|
|
||||||
const AddModels = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
|
|
||||||
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
|
|
||||||
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
|
|
||||||
const { data } = useGetModelImportsQuery();
|
|
||||||
console.log({ data });
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
|
|
||||||
<ButtonGroup>
|
|
||||||
<Button size="sm" isChecked={addModelMode === 'simple'} onClick={handleAddModelSimple}>
|
|
||||||
{t('common.simple')}
|
|
||||||
</Button>
|
|
||||||
<Button size="sm" isChecked={addModelMode === 'advanced'} onClick={handleAddModelAdvanced}>
|
|
||||||
{t('common.advanced')}
|
|
||||||
</Button>
|
|
||||||
</ButtonGroup>
|
|
||||||
<Flex p={4} borderRadius={4} bg="base.800">
|
|
||||||
{addModelMode === 'simple' && <SimpleAddModels />}
|
|
||||||
{addModelMode === 'advanced' && <AdvancedAddModels />}
|
|
||||||
</Flex>
|
|
||||||
<Flex>{data?.map((model) => <Text key={model.id}>{model.status}</Text>)}</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(AddModels);
|
|
@ -1,168 +0,0 @@
|
|||||||
import { Button, Checkbox, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
|
||||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
|
||||||
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
|
|
||||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { CSSProperties, FocusEventHandler } from 'react';
|
|
||||||
import { memo, useCallback, useState } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { CheckpointModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import { getModelName } from './util';
|
|
||||||
|
|
||||||
type AdvancedAddCheckpointProps = {
|
|
||||||
model_path?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { model_path } = props;
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
getValues,
|
|
||||||
setValue,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<CheckpointModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
model_name: model_path ? getModelName(model_path) : '',
|
|
||||||
base_model: 'sd-1',
|
|
||||||
model_type: 'main',
|
|
||||||
path: model_path ? model_path : '',
|
|
||||||
description: '',
|
|
||||||
model_format: 'checkpoint',
|
|
||||||
error: undefined,
|
|
||||||
vae: '',
|
|
||||||
variant: 'normal',
|
|
||||||
config: 'configs\\stable-diffusion\\v1-inference.yaml',
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const [addMainModel] = useAddMainModelsMutation();
|
|
||||||
|
|
||||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
addMainModel({
|
|
||||||
body: values,
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelAdded', {
|
|
||||||
modelName: values.model_name,
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reset();
|
|
||||||
|
|
||||||
// Close Advanced Panel in Scan Models tab
|
|
||||||
if (model_path) {
|
|
||||||
dispatch(setAdvancedAddScanModel(null));
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[addMainModel, dispatch, model_path, reset, t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
|
|
||||||
(e) => {
|
|
||||||
if (getValues().model_name === '') {
|
|
||||||
const modelName = getModelName(e.currentTarget.value);
|
|
||||||
if (modelName) {
|
|
||||||
setValue('model_name', modelName as string);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[getValues, setValue]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
|
||||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('model_name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" />
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
onBlur,
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Input {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
|
||||||
<Flex flexDirection="column" width="100%" gap={2}>
|
|
||||||
{!useCustomConfig ? (
|
|
||||||
<CheckpointConfigsSelect control={control} name="config" />
|
|
||||||
) : (
|
|
||||||
<FormControl isRequired>
|
|
||||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
|
||||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
|
||||||
</FormControl>
|
|
||||||
<Button mt={2} type="submit">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(AdvancedAddCheckpoint);
|
|
@ -1,148 +0,0 @@
|
|||||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
|
||||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
|
||||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { CSSProperties, FocusEventHandler } from 'react';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { DiffusersModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import { getModelName } from './util';
|
|
||||||
|
|
||||||
type AdvancedAddDiffusersProps = {
|
|
||||||
model_path?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
const AdvancedAddDiffusers = (props: AdvancedAddDiffusersProps) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { model_path } = props;
|
|
||||||
|
|
||||||
const [addMainModel] = useAddMainModelsMutation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
getValues,
|
|
||||||
setValue,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<DiffusersModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
model_name: model_path ? getModelName(model_path, false) : '',
|
|
||||||
base_model: 'sd-1',
|
|
||||||
model_type: 'main',
|
|
||||||
path: model_path ? model_path : '',
|
|
||||||
description: '',
|
|
||||||
model_format: 'diffusers',
|
|
||||||
error: undefined,
|
|
||||||
vae: '',
|
|
||||||
variant: 'normal',
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
addMainModel({
|
|
||||||
body: values,
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelAdded', {
|
|
||||||
modelName: values.model_name,
|
|
||||||
}),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
reset();
|
|
||||||
// Close Advanced Panel in Scan Models tab
|
|
||||||
if (model_path) {
|
|
||||||
dispatch(setAdvancedAddScanModel(null));
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[addMainModel, dispatch, model_path, reset, t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
|
|
||||||
(e) => {
|
|
||||||
if (getValues().model_name === '') {
|
|
||||||
const modelName = getModelName(e.currentTarget.value, false);
|
|
||||||
if (modelName) {
|
|
||||||
setValue('model_name', modelName as string);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[getValues, setValue]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('model_name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" />
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
onBlur,
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Input {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
|
||||||
|
|
||||||
<Button mt={2} type="submit">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(AdvancedAddDiffusers);
|
|
@ -1,50 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { z } from 'zod';
|
|
||||||
|
|
||||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
|
||||||
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
|
||||||
|
|
||||||
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
|
|
||||||
export type ManualAddMode = z.infer<typeof zManualAddMode>;
|
|
||||||
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
|
|
||||||
|
|
||||||
const AdvancedAddModels = () => {
|
|
||||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const handleChange: ComboboxOnChange = useCallback((v) => {
|
|
||||||
if (!isManualAddMode(v?.value)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setAdvancedAddMode(v.value);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = useMemo(
|
|
||||||
() => [
|
|
||||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
|
||||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
|
||||||
],
|
|
||||||
[t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" gap={4} width="100%">
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
|
||||||
<Combobox value={value} options={options} onChange={handleChange} />
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Flex p={4} borderRadius={4} bg="base.850">
|
|
||||||
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
|
|
||||||
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(AdvancedAddModels);
|
|
@ -1,176 +0,0 @@
|
|||||||
import { Button, Flex, FormControl, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
|
||||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { difference, forEach, intersection, map, values } from 'lodash-es';
|
|
||||||
import type { ChangeEvent, MouseEvent } from 'react';
|
|
||||||
import { memo, useCallback, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import type { SearchFolderResponse } from 'services/api/endpoints/models';
|
|
||||||
import {
|
|
||||||
useGetMainModelsQuery,
|
|
||||||
useGetModelsInFolderQuery,
|
|
||||||
useImportMainModelsMutation,
|
|
||||||
} from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
const FoundModelsList = () => {
|
|
||||||
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
|
|
||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
|
||||||
|
|
||||||
// Get paths of models that are already installed
|
|
||||||
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
|
||||||
|
|
||||||
// Get all model paths from a given directory
|
|
||||||
const { foundModels, alreadyInstalled, filteredModels } = useGetModelsInFolderQuery(
|
|
||||||
{
|
|
||||||
search_path: searchFolder ? searchFolder : '',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
selectFromResult: ({ data }) => {
|
|
||||||
const installedModelValues = values(installedModels?.entities);
|
|
||||||
const installedModelPaths = map(installedModelValues, 'path');
|
|
||||||
// Only select models those that aren't already installed to Invoke
|
|
||||||
const notInstalledModels = difference(data, installedModelPaths);
|
|
||||||
const alreadyInstalled = intersection(data, installedModelPaths);
|
|
||||||
return {
|
|
||||||
foundModels: data,
|
|
||||||
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
|
|
||||||
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const quickAddHandler = useCallback(
|
|
||||||
(e: MouseEvent<HTMLButtonElement>) => {
|
|
||||||
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
|
|
||||||
importMainModel({
|
|
||||||
body: {
|
|
||||||
location: e.currentTarget.id,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `Added Model: ${model_name}`,
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, importMainModel, t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
|
||||||
setNameFilter(e.target.value);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleClickSetAdvanced = useCallback((model: string) => dispatch(setAdvancedAddScanModel(model)), [dispatch]);
|
|
||||||
|
|
||||||
const renderModels = ({ models, showActions = true }: { models: string[]; showActions?: boolean }) => {
|
|
||||||
return models.map((model) => {
|
|
||||||
return (
|
|
||||||
<Flex key={model} p={4} gap={4} alignItems="center" borderRadius={4} bg="base.800">
|
|
||||||
<Flex w="full" minW="25%" flexDir="column">
|
|
||||||
<Text fontWeight="semibold">{model.split('\\').slice(-1)[0]}</Text>
|
|
||||||
<Text fontSize="sm" color="base.400">
|
|
||||||
{model}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
{showActions ? (
|
|
||||||
<Flex gap={2}>
|
|
||||||
<Button id={model} onClick={quickAddHandler} isLoading={isLoading}>
|
|
||||||
{t('modelManager.quickAdd')}
|
|
||||||
</Button>
|
|
||||||
<Button onClick={handleClickSetAdvanced.bind(null, model)} isLoading={isLoading}>
|
|
||||||
{t('modelManager.advanced')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
) : (
|
|
||||||
<Text fontWeight="semibold" p={2} borderRadius={4} color="invokeBlue.100" bg="invokeBlue.600">
|
|
||||||
{t('common.installed')}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
const renderFoundModels = () => {
|
|
||||||
if (!searchFolder) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!foundModels || foundModels.length === 0) {
|
|
||||||
return (
|
|
||||||
<Flex w="full" h="full" justifyContent="center" alignItems="center" height={96} userSelect="none" bg="base.900">
|
|
||||||
<Text variant="subtext">{t('modelManager.noModels')}</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" gap={2} w="100%" minW="50%">
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.search')}</FormLabel>
|
|
||||||
<Input onChange={handleSearchFilter} />
|
|
||||||
</FormControl>
|
|
||||||
<Flex p={2} gap={2}>
|
|
||||||
<Text fontWeight="semibold">
|
|
||||||
{t('modelManager.modelsFound')}: {foundModels.length}
|
|
||||||
</Text>
|
|
||||||
<Text fontWeight="semibold" color="invokeBlue.200">
|
|
||||||
{t('common.notInstalled')}: {filteredModels.length}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<ScrollableContent>
|
|
||||||
<Flex gap={2} flexDirection="column">
|
|
||||||
{renderModels({ models: filteredModels })}
|
|
||||||
{renderModels({ models: alreadyInstalled, showActions: false })}
|
|
||||||
</Flex>
|
|
||||||
</ScrollableContent>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return renderFoundModels();
|
|
||||||
};
|
|
||||||
|
|
||||||
const foundModelsFilter = (data: SearchFolderResponse | undefined, nameFilter: string) => {
|
|
||||||
const filteredModels: SearchFolderResponse = [];
|
|
||||||
forEach(data, (model) => {
|
|
||||||
if (!model) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model.includes(nameFilter)) {
|
|
||||||
filteredModels.push(model);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return filteredModels;
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(FoundModelsList);
|
|
@ -1,95 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Box, Combobox, Flex, FormControl, FormLabel, IconButton, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
|
||||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiXBold } from 'react-icons/pi';
|
|
||||||
|
|
||||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
|
||||||
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
|
||||||
import type { ManualAddMode } from './AdvancedAddModels';
|
|
||||||
import { isManualAddMode } from './AdvancedAddModels';
|
|
||||||
|
|
||||||
const ScanAdvancedAddModels = () => {
|
|
||||||
const advancedAddScanModel = useAppSelector((s) => s.modelmanager.advancedAddScanModel);
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = useMemo(
|
|
||||||
() => [
|
|
||||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
|
||||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
|
||||||
],
|
|
||||||
[t]
|
|
||||||
);
|
|
||||||
|
|
||||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
|
||||||
|
|
||||||
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
advancedAddScanModel && ['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) => advancedAddScanModel.endsWith(ext))
|
|
||||||
? setAdvancedAddMode('checkpoint')
|
|
||||||
: setAdvancedAddMode('diffusers');
|
|
||||||
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const handleClickSetAdvanced = useCallback(() => dispatch(setAdvancedAddScanModel(null)), [dispatch]);
|
|
||||||
|
|
||||||
const handleChangeAddMode = useCallback<ComboboxOnChange>((v) => {
|
|
||||||
if (!isManualAddMode(v?.value)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setAdvancedAddMode(v.value);
|
|
||||||
if (v.value === 'checkpoint') {
|
|
||||||
setIsCheckpoint(true);
|
|
||||||
} else {
|
|
||||||
setIsCheckpoint(false);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
|
||||||
|
|
||||||
if (!advancedAddScanModel) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Box
|
|
||||||
display="flex"
|
|
||||||
flexDirection="column"
|
|
||||||
minWidth="40%"
|
|
||||||
maxHeight="calc(100vh - 300px)"
|
|
||||||
overflow="scroll"
|
|
||||||
p={4}
|
|
||||||
gap={4}
|
|
||||||
borderRadius={4}
|
|
||||||
bg="base.800"
|
|
||||||
>
|
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
|
||||||
<Text size="xl" fontWeight="semibold">
|
|
||||||
{isCheckpoint || advancedAddMode === 'checkpoint' ? 'Add Checkpoint Model' : 'Add Diffusers Model'}
|
|
||||||
</Text>
|
|
||||||
<IconButton
|
|
||||||
icon={<PiXBold />}
|
|
||||||
aria-label={t('modelManager.closeAdvanced')}
|
|
||||||
onClick={handleClickSetAdvanced}
|
|
||||||
size="sm"
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
|
||||||
<Combobox value={value} options={options} onChange={handleChangeAddMode} />
|
|
||||||
</FormControl>
|
|
||||||
{isCheckpoint ? (
|
|
||||||
<AdvancedAddCheckpoint key={advancedAddScanModel} model_path={advancedAddScanModel} />
|
|
||||||
) : (
|
|
||||||
<AdvancedAddDiffusers key={advancedAddScanModel} model_path={advancedAddScanModel} />
|
|
||||||
)}
|
|
||||||
</Box>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ScanAdvancedAddModels);
|
|
@ -1,22 +0,0 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
import FoundModelsList from './FoundModelsList';
|
|
||||||
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
|
|
||||||
import SearchFolderForm from './SearchFolderForm';
|
|
||||||
|
|
||||||
const ScanModels = () => {
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" w="100%" h="full" gap={4}>
|
|
||||||
<SearchFolderForm />
|
|
||||||
<Flex gap={4}>
|
|
||||||
<Flex overflow="scroll" gap={4} w="100%" h="full">
|
|
||||||
<FoundModelsList />
|
|
||||||
</Flex>
|
|
||||||
<ScanAdvancedAddModels />
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ScanModels);
|
|
@ -1,103 +0,0 @@
|
|||||||
import { Flex, IconButton, Input, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { useForm } from '@mantine/form';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { setAdvancedAddScanModel, setSearchFolder } from 'features/modelManager/store/modelManagerSlice';
|
|
||||||
import type { CSSProperties } from 'react';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiArrowsCounterClockwiseBold, PiMagnifyingGlassBold, PiTrashSimpleBold } from 'react-icons/pi';
|
|
||||||
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
type SearchFolderForm = {
|
|
||||||
folder: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
function SearchFolderForm() {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
|
|
||||||
|
|
||||||
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
|
|
||||||
search_path: searchFolder ? searchFolder : '',
|
|
||||||
});
|
|
||||||
|
|
||||||
const searchFolderForm = useForm<SearchFolderForm>({
|
|
||||||
initialValues: {
|
|
||||||
folder: '',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const searchFolderFormSubmitHandler = useCallback(
|
|
||||||
(values: SearchFolderForm) => {
|
|
||||||
dispatch(setSearchFolder(values.folder));
|
|
||||||
},
|
|
||||||
[dispatch]
|
|
||||||
);
|
|
||||||
|
|
||||||
const scanAgainHandler = useCallback(() => {
|
|
||||||
refetchFoundModels();
|
|
||||||
}, [refetchFoundModels]);
|
|
||||||
|
|
||||||
const handleClickClearCheckpointFolder = useCallback(() => {
|
|
||||||
dispatch(setSearchFolder(null));
|
|
||||||
dispatch(setAdvancedAddScanModel(null));
|
|
||||||
}, [dispatch]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={searchFolderForm.onSubmit((values) => searchFolderFormSubmitHandler(values))} style={formStyles}>
|
|
||||||
<Flex w="100%" gap={2} borderRadius={4} alignItems="center">
|
|
||||||
<Flex w="100%" alignItems="center" gap={4} minH={12}>
|
|
||||||
<Text fontSize="sm" fontWeight="semibold" color="base.300" minW="max-content">
|
|
||||||
{t('common.folder')}
|
|
||||||
</Text>
|
|
||||||
{!searchFolder ? (
|
|
||||||
<Input w="100%" size="md" {...searchFolderForm.getInputProps('folder')} />
|
|
||||||
) : (
|
|
||||||
<Flex w="100%" p={2} px={4} bg="base.700" borderRadius={4} fontSize="sm" fontWeight="bold">
|
|
||||||
{searchFolder}
|
|
||||||
</Flex>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex gap={2}>
|
|
||||||
{!searchFolder ? (
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('modelManager.findModels')}
|
|
||||||
tooltip={t('modelManager.findModels')}
|
|
||||||
icon={<PiMagnifyingGlassBold />}
|
|
||||||
fontSize={18}
|
|
||||||
size="sm"
|
|
||||||
type="submit"
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('modelManager.scanAgain')}
|
|
||||||
tooltip={t('modelManager.scanAgain')}
|
|
||||||
icon={<PiArrowsCounterClockwiseBold />}
|
|
||||||
onClick={scanAgainHandler}
|
|
||||||
fontSize={18}
|
|
||||||
size="sm"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
|
||||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
|
||||||
icon={<PiTrashSimpleBold />}
|
|
||||||
size="sm"
|
|
||||||
onClick={handleClickClearCheckpointFolder}
|
|
||||||
isDisabled={!searchFolder}
|
|
||||||
colorScheme="red"
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export default memo(SearchFolderForm);
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
@ -1,92 +0,0 @@
|
|||||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Button, Combobox, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
|
|
||||||
import { useForm } from '@mantine/form';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { CSSProperties } from 'react';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ label: 'None', value: 'none' },
|
|
||||||
{ label: 'v_prediction', value: 'v_prediction' },
|
|
||||||
{ label: 'epsilon', value: 'epsilon' },
|
|
||||||
{ label: 'sample', value: 'sample' },
|
|
||||||
];
|
|
||||||
|
|
||||||
type ExtendedImportModelConfig = {
|
|
||||||
location: string;
|
|
||||||
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
|
|
||||||
};
|
|
||||||
|
|
||||||
const SimpleAddModels = () => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
|
||||||
|
|
||||||
const addModelForm = useForm<ExtendedImportModelConfig>({
|
|
||||||
initialValues: {
|
|
||||||
location: '',
|
|
||||||
prediction_type: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
|
|
||||||
const importModelResponseBody = {
|
|
||||||
config: values.prediction_type === 'none' ? undefined : values.prediction_type,
|
|
||||||
};
|
|
||||||
|
|
||||||
importMainModel({ source: values.location, config: importModelResponseBody })
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.modelAddedSimple'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
addModelForm.reset();
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${error.data.detail} `,
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))} style={formStyles}>
|
|
||||||
<Flex flexDirection="column" width="100%" gap={4}>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input placeholder={t('modelManager.simpleModelDesc')} w="100%" {...addModelForm.getInputProps('location')} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
|
||||||
<Combobox options={options} defaultValue={options[0]} {...addModelForm.getInputProps('prediction_type')} />
|
|
||||||
</FormControl>
|
|
||||||
<Button type="submit" isLoading={isLoading}>
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const formStyles: CSSProperties = {
|
|
||||||
width: '100%',
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(SimpleAddModels);
|
|
@ -1,15 +0,0 @@
|
|||||||
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
|
|
||||||
let regex;
|
|
||||||
if (isCheckpoint) {
|
|
||||||
regex = new RegExp('[^\\\\/]+(?=\\.)');
|
|
||||||
} else {
|
|
||||||
regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
|
|
||||||
}
|
|
||||||
|
|
||||||
const match = filepath.match(regex);
|
|
||||||
if (match) {
|
|
||||||
return match[0];
|
|
||||||
} else {
|
|
||||||
return '';
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
|
||||||
import { memo, useCallback, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
import AddModels from './AddModelsPanel/AddModels';
|
|
||||||
import ScanModels from './AddModelsPanel/ScanModels';
|
|
||||||
|
|
||||||
type AddModelTabs = 'add' | 'scan';
|
|
||||||
|
|
||||||
const ImportModelsPanel = () => {
|
|
||||||
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const handleClickAddTab = useCallback(() => setAddModelTab('add'), []);
|
|
||||||
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" gap={4} h="full">
|
|
||||||
<ButtonGroup>
|
|
||||||
<Button onClick={handleClickAddTab} isChecked={addModelTab === 'add'} size="sm" width="100%">
|
|
||||||
{t('modelManager.addModel')}
|
|
||||||
</Button>
|
|
||||||
<Button onClick={handleClickScanTab} isChecked={addModelTab === 'scan'} size="sm" width="100%">
|
|
||||||
{t('modelManager.scanForModels')}
|
|
||||||
</Button>
|
|
||||||
</ButtonGroup>
|
|
||||||
|
|
||||||
{addModelTab === 'add' && <AddModels />}
|
|
||||||
{addModelTab === 'scan' && <ScanModels />}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ImportModelsPanel);
|
|
@ -1,352 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import {
|
|
||||||
Button,
|
|
||||||
Checkbox,
|
|
||||||
Combobox,
|
|
||||||
CompositeNumberInput,
|
|
||||||
CompositeSlider,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormHelperText,
|
|
||||||
FormLabel,
|
|
||||||
Input,
|
|
||||||
Radio,
|
|
||||||
RadioGroup,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { pickBy } from 'lodash-es';
|
|
||||||
import type { ChangeEvent } from 'react';
|
|
||||||
import { memo, useCallback, useMemo, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import { useGetMainModelsQuery, useMergeMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { BaseModelType, MergeModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const baseModelTypeSelectOptions: ComboboxOption[] = [
|
|
||||||
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
|
||||||
{ label: 'Stable Diffusion 2', value: 'sd-2' },
|
|
||||||
];
|
|
||||||
|
|
||||||
type MergeInterpolationMethods = 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
|
||||||
|
|
||||||
const MergeModelsPanel = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
|
||||||
|
|
||||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
|
||||||
|
|
||||||
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
|
|
||||||
const valueBaseModel = useMemo(() => baseModelTypeSelectOptions.find((o) => o.value === baseModel), [baseModel]);
|
|
||||||
const sd1DiffusersModels = pickBy(
|
|
||||||
data?.entities,
|
|
||||||
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
|
|
||||||
);
|
|
||||||
|
|
||||||
const sd2DiffusersModels = pickBy(
|
|
||||||
data?.entities,
|
|
||||||
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
|
|
||||||
);
|
|
||||||
|
|
||||||
const modelsMap = useMemo(() => {
|
|
||||||
return {
|
|
||||||
'sd-1': sd1DiffusersModels,
|
|
||||||
'sd-2': sd2DiffusersModels,
|
|
||||||
};
|
|
||||||
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
|
||||||
|
|
||||||
const [modelOne, setModelOne] = useState<string | null>(
|
|
||||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[0] ?? null
|
|
||||||
);
|
|
||||||
const [modelTwo, setModelTwo] = useState<string | null>(
|
|
||||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[1] ?? null
|
|
||||||
);
|
|
||||||
const [modelThree, setModelThree] = useState<string | null>(null);
|
|
||||||
|
|
||||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
|
||||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
|
||||||
|
|
||||||
const [modelMergeInterp, setModelMergeInterp] = useState<MergeInterpolationMethods>('weighted_sum');
|
|
||||||
|
|
||||||
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<'root' | 'custom'>('root');
|
|
||||||
|
|
||||||
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = useState<string>('');
|
|
||||||
|
|
||||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
|
||||||
|
|
||||||
const optionsModelOne = useMemo(
|
|
||||||
() =>
|
|
||||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
|
||||||
.filter((model) => model !== modelTwo && model !== modelThree)
|
|
||||||
.map((model) => ({ label: model, value: model })),
|
|
||||||
[modelsMap, baseModel, modelTwo, modelThree]
|
|
||||||
);
|
|
||||||
|
|
||||||
const optionsModelTwo = useMemo(
|
|
||||||
() =>
|
|
||||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
|
||||||
.filter((model) => model !== modelOne && model !== modelThree)
|
|
||||||
.map((model) => ({ label: model, value: model })),
|
|
||||||
[modelsMap, baseModel, modelOne, modelThree]
|
|
||||||
);
|
|
||||||
|
|
||||||
const optionsModelThree = useMemo(
|
|
||||||
() =>
|
|
||||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
|
||||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
|
||||||
.map((model) => ({ label: model, value: model })),
|
|
||||||
[modelsMap, baseModel, modelOne, modelTwo]
|
|
||||||
);
|
|
||||||
|
|
||||||
const onChangeBaseModel = useCallback<ComboboxOnChange>((v) => {
|
|
||||||
if (!v) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (!(v.value === 'sd-1' || v.value === 'sd-2')) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setBaseModel(v.value);
|
|
||||||
setModelOne(null);
|
|
||||||
setModelTwo(null);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const onChangeModelOne = useCallback<ComboboxOnChange>((v) => {
|
|
||||||
if (!v) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setModelOne(v.value);
|
|
||||||
}, []);
|
|
||||||
const onChangeModelTwo = useCallback<ComboboxOnChange>((v) => {
|
|
||||||
if (!v) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setModelTwo(v.value);
|
|
||||||
}, []);
|
|
||||||
const onChangeModelThree = useCallback<ComboboxOnChange>((v) => {
|
|
||||||
if (!v) {
|
|
||||||
setModelThree(null);
|
|
||||||
setModelMergeInterp('add_difference');
|
|
||||||
} else {
|
|
||||||
setModelThree(v.value);
|
|
||||||
setModelMergeInterp('weighted_sum');
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const valueModelOne = useMemo(() => optionsModelOne.find((o) => o.value === modelOne), [modelOne, optionsModelOne]);
|
|
||||||
const valueModelTwo = useMemo(() => optionsModelTwo.find((o) => o.value === modelTwo), [modelTwo, optionsModelTwo]);
|
|
||||||
const valueModelThree = useMemo(
|
|
||||||
() => optionsModelThree.find((o) => o.value === modelThree),
|
|
||||||
[modelThree, optionsModelThree]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleChangeMergedModelName = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => setMergedModelName(e.target.value),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
const handleChangeModelMergeAlpha = useCallback((v: number) => setModelMergeAlpha(v), []);
|
|
||||||
const handleResetModelMergeAlpha = useCallback(() => setModelMergeAlpha(0.5), []);
|
|
||||||
const handleChangeMergeInterp = useCallback((v: MergeInterpolationMethods) => setModelMergeInterp(v), []);
|
|
||||||
const handleChangeMergeSaveLocType = useCallback((v: 'root' | 'custom') => setModelMergeSaveLocType(v), []);
|
|
||||||
const handleChangeMergeCustomSaveLoc = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => setModelMergeCustomSaveLoc(e.target.value),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
const handleChangeModelMergeForce = useCallback(
|
|
||||||
(e: ChangeEvent<HTMLInputElement>) => setModelMergeForce(e.target.checked),
|
|
||||||
[]
|
|
||||||
);
|
|
||||||
|
|
||||||
const mergeModelsHandler = useCallback(() => {
|
|
||||||
const models_names: string[] = [];
|
|
||||||
|
|
||||||
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
|
|
||||||
modelsToMerge = modelsToMerge.filter((model) => model !== null);
|
|
||||||
modelsToMerge.forEach((model) => {
|
|
||||||
const n = model?.split('/')?.[2];
|
|
||||||
if (n) {
|
|
||||||
models_names.push(n);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
const mergeModelsInfo: MergeModelConfig['body'] = {
|
|
||||||
model_names: models_names,
|
|
||||||
merged_model_name: mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
|
||||||
alpha: modelMergeAlpha,
|
|
||||||
interp: modelMergeInterp,
|
|
||||||
force: modelMergeForce,
|
|
||||||
merge_dest_directory: modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
|
|
||||||
};
|
|
||||||
|
|
||||||
mergeModels({
|
|
||||||
base_model: baseModel,
|
|
||||||
body: { body: mergeModelsInfo },
|
|
||||||
})
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelsMerged'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelsMergeFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}, [
|
|
||||||
baseModel,
|
|
||||||
dispatch,
|
|
||||||
mergeModels,
|
|
||||||
mergedModelName,
|
|
||||||
modelMergeAlpha,
|
|
||||||
modelMergeCustomSaveLoc,
|
|
||||||
modelMergeForce,
|
|
||||||
modelMergeInterp,
|
|
||||||
modelMergeSaveLocType,
|
|
||||||
modelOne,
|
|
||||||
modelThree,
|
|
||||||
modelTwo,
|
|
||||||
t,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDir="column" gap={4}>
|
|
||||||
<Flex flexDir="column" gap={1}>
|
|
||||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
|
||||||
<Text fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.modelMergeHeaderHelp2')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<FormControl w="full">
|
|
||||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
|
||||||
<Combobox options={baseModelTypeSelectOptions} value={valueBaseModel} onChange={onChangeBaseModel} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl w="full">
|
|
||||||
<FormLabel>{t('modelManager.modelOne')}</FormLabel>
|
|
||||||
<Combobox options={optionsModelOne} value={valueModelOne} onChange={onChangeModelOne} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl w="full">
|
|
||||||
<FormLabel>{t('modelManager.modelTwo')}</FormLabel>
|
|
||||||
<Combobox options={optionsModelTwo} value={valueModelTwo} onChange={onChangeModelTwo} />
|
|
||||||
</FormControl>
|
|
||||||
<FormControl w="full">
|
|
||||||
<FormLabel>{t('modelManager.modelThree')}</FormLabel>
|
|
||||||
<Combobox options={optionsModelThree} value={valueModelThree} onChange={onChangeModelThree} isClearable />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.mergedModelName')}</FormLabel>
|
|
||||||
<Input value={mergedModelName} onChange={handleChangeMergedModelName} />
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.800">
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.alpha')}</FormLabel>
|
|
||||||
<CompositeSlider
|
|
||||||
min={0.01}
|
|
||||||
max={0.99}
|
|
||||||
step={0.01}
|
|
||||||
value={modelMergeAlpha}
|
|
||||||
onChange={handleChangeModelMergeAlpha}
|
|
||||||
onReset={handleResetModelMergeAlpha}
|
|
||||||
marks
|
|
||||||
/>
|
|
||||||
<CompositeNumberInput
|
|
||||||
min={0.01}
|
|
||||||
max={0.99}
|
|
||||||
step={0.01}
|
|
||||||
value={modelMergeAlpha}
|
|
||||||
onChange={handleChangeModelMergeAlpha}
|
|
||||||
onReset={handleResetModelMergeAlpha}
|
|
||||||
/>
|
|
||||||
<FormHelperText>{t('modelManager.modelMergeAlphaHelp')}</FormHelperText>
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex padding={4} gap={4} borderRadius="base" bg="base.800">
|
|
||||||
<Text fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.interpolationType')}
|
|
||||||
</Text>
|
|
||||||
<RadioGroup value={modelMergeInterp} onChange={handleChangeMergeInterp}>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
{modelThree === null ? (
|
|
||||||
<>
|
|
||||||
<Radio value="weighted_sum">
|
|
||||||
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
|
||||||
</Radio>
|
|
||||||
<Radio value="sigmoid">
|
|
||||||
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
|
||||||
</Radio>
|
|
||||||
<Radio value="inv_sigmoid">
|
|
||||||
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
|
|
||||||
</Radio>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<Radio value="add_difference">
|
|
||||||
<Tooltip label={t('modelManager.modelMergeInterpAddDifferenceHelp')}>
|
|
||||||
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
|
|
||||||
</Tooltip>
|
|
||||||
</Radio>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</RadioGroup>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.900">
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<Text fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.mergedModelSaveLocation')}
|
|
||||||
</Text>
|
|
||||||
<RadioGroup value={modelMergeSaveLocType} onChange={handleChangeMergeSaveLocType}>
|
|
||||||
<Flex columnGap={4}>
|
|
||||||
<Radio value="root">
|
|
||||||
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
|
|
||||||
</Radio>
|
|
||||||
|
|
||||||
<Radio value="custom">
|
|
||||||
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
|
||||||
</Radio>
|
|
||||||
</Flex>
|
|
||||||
</RadioGroup>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
{modelMergeSaveLocType === 'custom' && (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.mergedModelCustomSaveLocation')}</FormLabel>
|
|
||||||
<Input value={modelMergeCustomSaveLoc} onChange={handleChangeMergeCustomSaveLoc} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.ignoreMismatch')}</FormLabel>
|
|
||||||
<Checkbox isChecked={modelMergeForce} onChange={handleChangeModelMergeForce} />
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Button onClick={mergeModelsHandler} isLoading={isLoading} isDisabled={modelOne === null || modelTwo === null}>
|
|
||||||
{t('modelManager.merge')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(MergeModelsPanel);
|
|
@ -1,63 +0,0 @@
|
|||||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { memo, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
|
||||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
|
||||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
|
||||||
import ModelList from './ModelManagerPanel/ModelList';
|
|
||||||
|
|
||||||
const ModelManagerPanel = () => {
|
|
||||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
|
||||||
const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
|
||||||
selectFromResult: ({ data }) => ({
|
|
||||||
mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
const { loraModel } = useGetLoRAModelsQuery(undefined, {
|
|
||||||
selectFromResult: ({ data }) => ({
|
|
||||||
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const model = mainModel ? mainModel : loraModel;
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={8} w="full" h="full">
|
|
||||||
<ModelList selectedModelId={selectedModelId} setSelectedModelId={setSelectedModelId} />
|
|
||||||
<ModelEdit model={model} />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
type ModelEditProps = {
|
|
||||||
model: MainModelConfig | LoRAConfig | undefined;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ModelEdit = (props: ModelEditProps) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
if (model?.format === 'checkpoint') {
|
|
||||||
return <CheckpointModelEdit key={model.key} model={model} />;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model?.format === 'diffusers') {
|
|
||||||
return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (model?.type === 'lora') {
|
|
||||||
return <LoRAModelEdit key={model.key} model={model} />;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex w="full" h="full" justifyContent="center" alignItems="center" maxH={96} userSelect="none">
|
|
||||||
<Text variant="subtext">{t('modelManager.noModelSelected')}</Text>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ModelManagerPanel);
|
|
@ -1,185 +0,0 @@
|
|||||||
import {
|
|
||||||
Badge,
|
|
||||||
Button,
|
|
||||||
Checkbox,
|
|
||||||
Divider,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormErrorMessage,
|
|
||||||
FormLabel,
|
|
||||||
Input,
|
|
||||||
Text,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
|
||||||
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
|
|
||||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { memo, useCallback, useEffect, useState } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { CheckpointModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import ModelConvert from './ModelConvert';
|
|
||||||
|
|
||||||
type CheckpointModelEditProps = {
|
|
||||||
model: CheckpointModelConfig;
|
|
||||||
};
|
|
||||||
|
|
||||||
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
|
||||||
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
|
||||||
|
|
||||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!availableCheckpointConfigs?.includes(model.config)) {
|
|
||||||
setUseCustomConfig(true);
|
|
||||||
}
|
|
||||||
}, [availableCheckpointConfigs, model.config]);
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<CheckpointModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: model.name ? model.name : '',
|
|
||||||
base: model.base,
|
|
||||||
type: 'main',
|
|
||||||
path: model.path ? model.path : '',
|
|
||||||
description: model.description ? model.description : '',
|
|
||||||
format: 'checkpoint',
|
|
||||||
vae: model.vae ? model.vae : '',
|
|
||||||
config: model.config ? model.config : '',
|
|
||||||
variant: model.variant,
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
const responseBody = {
|
|
||||||
key: model.key,
|
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
reset();
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, model.key, reset, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
|
||||||
<Flex flexDirection="column">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{model.name}
|
|
||||||
</Text>
|
|
||||||
<Text fontSize="sm" color="base.400">
|
|
||||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
{![''].includes(model.base) ? (
|
|
||||||
<ModelConvert model={model} />
|
|
||||||
) : (
|
|
||||||
<Badge p={2} borderRadius={4} bg="error.400">
|
|
||||||
{t('modelManager.conversionNotSupported')}
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
<Divider />
|
|
||||||
|
|
||||||
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Input {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
|
|
||||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
{!useCustomConfig ? (
|
|
||||||
<CheckpointConfigsSelect control={control} name="config" />
|
|
||||||
) : (
|
|
||||||
<FormControl isRequired>
|
|
||||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
|
||||||
<Input {...register('config')} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
|
||||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
|
||||||
</FormControl>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Button type="submit" isLoading={isLoading}>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(CheckpointModelEdit);
|
|
@ -1,133 +0,0 @@
|
|||||||
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
|
||||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { DiffusersModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
type DiffusersModelEditProps = {
|
|
||||||
model: DiffusersModelConfig;
|
|
||||||
};
|
|
||||||
|
|
||||||
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<DiffusersModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: model.name ? model.name : '',
|
|
||||||
base: model.base,
|
|
||||||
type: 'main',
|
|
||||||
path: model.path ? model.path : '',
|
|
||||||
description: model.description ? model.description : '',
|
|
||||||
format: 'diffusers',
|
|
||||||
vae: model.vae ? model.vae : '',
|
|
||||||
variant: model.variant,
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
const responseBody = {
|
|
||||||
key: model.key,
|
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
reset();
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, model.key, reset, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex flexDirection="column">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{model.name}
|
|
||||||
</Text>
|
|
||||||
<Text fontSize="sm" color="base.400">
|
|
||||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<Divider />
|
|
||||||
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Input {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
|
|
||||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
|
||||||
<Input {...register('vae')} />
|
|
||||||
</FormControl>
|
|
||||||
<Button type="submit" isLoading={isLoading}>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(DiffusersModelEdit);
|
|
@ -1,127 +0,0 @@
|
|||||||
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
|
||||||
import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import type { SubmitHandler } from 'react-hook-form';
|
|
||||||
import { useForm } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { LoRAModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
type LoRAModelEditProps = {
|
|
||||||
model: LoRAModelConfig;
|
|
||||||
};
|
|
||||||
|
|
||||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
|
||||||
const { model } = props;
|
|
||||||
|
|
||||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
|
||||||
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
const {
|
|
||||||
register,
|
|
||||||
handleSubmit,
|
|
||||||
control,
|
|
||||||
formState: { errors },
|
|
||||||
reset,
|
|
||||||
} = useForm<LoRAModelConfig>({
|
|
||||||
defaultValues: {
|
|
||||||
name: model.name ? model.name : '',
|
|
||||||
base: model.base,
|
|
||||||
type: 'lora',
|
|
||||||
path: model.path ? model.path : '',
|
|
||||||
description: model.description ? model.description : '',
|
|
||||||
format: model.format,
|
|
||||||
},
|
|
||||||
mode: 'onChange',
|
|
||||||
});
|
|
||||||
|
|
||||||
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
|
||||||
(values) => {
|
|
||||||
const responseBody = {
|
|
||||||
key: model.key,
|
|
||||||
body: values,
|
|
||||||
};
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
|
||||||
.unwrap()
|
|
||||||
.then((payload) => {
|
|
||||||
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdated'),
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((_) => {
|
|
||||||
reset();
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.modelUpdateFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[dispatch, model.key, reset, t, updateModel]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
|
||||||
<Flex flexDirection="column">
|
|
||||||
<Text fontSize="lg" fontWeight="bold">
|
|
||||||
{model.name}
|
|
||||||
</Text>
|
|
||||||
<Text fontSize="sm" color="base.400">
|
|
||||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
|
|
||||||
{t('common.format')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<Divider />
|
|
||||||
|
|
||||||
<form onSubmit={handleSubmit(onSubmit)}>
|
|
||||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
|
||||||
<FormControl isInvalid={Boolean(errors.name)}>
|
|
||||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('name', {
|
|
||||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
|
||||||
<Input {...register('description')} />
|
|
||||||
</FormControl>
|
|
||||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base" />
|
|
||||||
|
|
||||||
<FormControl isInvalid={Boolean(errors.path)}>
|
|
||||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
|
||||||
<Input
|
|
||||||
{...register('path', {
|
|
||||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
|
||||||
})}
|
|
||||||
/>
|
|
||||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
|
||||||
</FormControl>
|
|
||||||
<Button type="submit" isLoading={isLoading}>
|
|
||||||
{t('modelManager.updateModel')}
|
|
||||||
</Button>
|
|
||||||
</Flex>
|
|
||||||
</form>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(LoRAModelEdit);
|
|
@ -1,165 +0,0 @@
|
|||||||
import {
|
|
||||||
Button,
|
|
||||||
ConfirmationAlertDialog,
|
|
||||||
Flex,
|
|
||||||
FormControl,
|
|
||||||
FormLabel,
|
|
||||||
Input,
|
|
||||||
ListItem,
|
|
||||||
Radio,
|
|
||||||
RadioGroup,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
UnorderedList,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import type { ChangeEvent } from 'react';
|
|
||||||
import { memo, useCallback, useEffect, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { CheckpointModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
interface ModelConvertProps {
|
|
||||||
model: CheckpointModelConfig;
|
|
||||||
}
|
|
||||||
|
|
||||||
type SaveLocation = 'InvokeAIRoot' | 'Custom';
|
|
||||||
|
|
||||||
const ModelConvert = (props: ModelConvertProps) => {
|
|
||||||
const { model } = props;
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
const [saveLocation, setSaveLocation] = useState<SaveLocation>('InvokeAIRoot');
|
|
||||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
setSaveLocation('InvokeAIRoot');
|
|
||||||
}, [model]);
|
|
||||||
|
|
||||||
const modelConvertCancelHandler = useCallback(() => {
|
|
||||||
setSaveLocation('InvokeAIRoot');
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleChangeSaveLocation = useCallback((v: string) => {
|
|
||||||
setSaveLocation(v as SaveLocation);
|
|
||||||
}, []);
|
|
||||||
const handleChangeCustomSaveLocation = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
|
||||||
setCustomSaveLocation(e.target.value);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const modelConvertHandler = useCallback(() => {
|
|
||||||
const queryArg = {
|
|
||||||
base_model: model.base,
|
|
||||||
model_name: model.name,
|
|
||||||
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (saveLocation === 'Custom' && customSaveLocation === '') {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('modelManager.noCustomLocationProvided'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
|
|
||||||
status: 'info',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
convertModel(queryArg)
|
|
||||||
.unwrap()
|
|
||||||
.then(() => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelConverted')}: ${model.name}`,
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch(() => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
<Button
|
|
||||||
onClick={onOpen}
|
|
||||||
size="sm"
|
|
||||||
aria-label={t('modelManager.convertToDiffusers')}
|
|
||||||
className=" modal-close-btn"
|
|
||||||
isLoading={isLoading}
|
|
||||||
>
|
|
||||||
🧨 {t('modelManager.convertToDiffusers')}
|
|
||||||
</Button>
|
|
||||||
<ConfirmationAlertDialog
|
|
||||||
title={`${t('modelManager.convert')} ${model.name}`}
|
|
||||||
acceptCallback={modelConvertHandler}
|
|
||||||
cancelCallback={modelConvertCancelHandler}
|
|
||||||
acceptButtonText={`${t('modelManager.convert')}`}
|
|
||||||
isOpen={isOpen}
|
|
||||||
onClose={onClose}
|
|
||||||
>
|
|
||||||
<Flex flexDirection="column" rowGap={4}>
|
|
||||||
<Text>{t('modelManager.convertToDiffusersHelpText1')}</Text>
|
|
||||||
<UnorderedList>
|
|
||||||
<ListItem>{t('modelManager.convertToDiffusersHelpText2')}</ListItem>
|
|
||||||
<ListItem>{t('modelManager.convertToDiffusersHelpText3')}</ListItem>
|
|
||||||
<ListItem>{t('modelManager.convertToDiffusersHelpText4')}</ListItem>
|
|
||||||
<ListItem>{t('modelManager.convertToDiffusersHelpText5')}</ListItem>
|
|
||||||
</UnorderedList>
|
|
||||||
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
|
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Flex flexDir="column" gap={2}>
|
|
||||||
<Flex marginTop={4} flexDir="column" gap={2}>
|
|
||||||
<Text fontWeight="semibold">{t('modelManager.convertToDiffusersSaveLocation')}</Text>
|
|
||||||
<RadioGroup value={saveLocation} onChange={handleChangeSaveLocation}>
|
|
||||||
<Flex gap={4}>
|
|
||||||
<Radio value="InvokeAIRoot">
|
|
||||||
<Tooltip label="Save converted model in the InvokeAI root folder">
|
|
||||||
{t('modelManager.invokeRoot')}
|
|
||||||
</Tooltip>
|
|
||||||
</Radio>
|
|
||||||
<Radio value="Custom">
|
|
||||||
<Tooltip label="Save converted model in a custom folder">{t('modelManager.custom')}</Tooltip>
|
|
||||||
</Radio>
|
|
||||||
</Flex>
|
|
||||||
</RadioGroup>
|
|
||||||
</Flex>
|
|
||||||
{saveLocation === 'Custom' && (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.customSaveLocation')}</FormLabel>
|
|
||||||
<Input width="full" value={customSaveLocation} onChange={handleChangeCustomSaveLocation} />
|
|
||||||
</FormControl>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</ConfirmationAlertDialog>
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ModelConvert);
|
|
@ -1,205 +0,0 @@
|
|||||||
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Input, Spinner, Text } from '@invoke-ai/ui-library';
|
|
||||||
import type { EntityState } from '@reduxjs/toolkit';
|
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
|
||||||
import { memo, useCallback, useState } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
|
||||||
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
|
||||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
import ModelListItem from './ModelListItem';
|
|
||||||
|
|
||||||
type ModelListProps = {
|
|
||||||
selectedModelId: string | undefined;
|
|
||||||
setSelectedModelId: (name: string | undefined) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
|
|
||||||
|
|
||||||
type ModelType = 'main' | 'lora';
|
|
||||||
|
|
||||||
type CombinedModelFormat = ModelFormat | 'lora';
|
|
||||||
|
|
||||||
const ModelList = (props: ModelListProps) => {
|
|
||||||
const { selectedModelId, setSelectedModelId } = props;
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const [nameFilter, setNameFilter] = useState<string>('');
|
|
||||||
const [modelFormatFilter, setModelFormatFilter] = useState<CombinedModelFormat>('all');
|
|
||||||
|
|
||||||
const { filteredDiffusersModels, isLoadingDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredDiffusersModels: modelsFilter(data, 'main', 'diffusers', nameFilter),
|
|
||||||
isLoadingDiffusersModels: isLoading,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredCheckpointModels, isLoadingCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredCheckpointModels: modelsFilter(data, 'main', 'checkpoint', nameFilter),
|
|
||||||
isLoadingCheckpointModels: isLoading,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
|
||||||
selectFromResult: ({ data, isLoading }) => ({
|
|
||||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
|
||||||
isLoadingLoraModels: isLoading,
|
|
||||||
}),
|
|
||||||
});
|
|
||||||
|
|
||||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
|
||||||
setNameFilter(e.target.value);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
|
||||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
|
||||||
<ButtonGroup>
|
|
||||||
<Button onClick={setModelFormatFilter.bind(null, 'all')} isChecked={modelFormatFilter === 'all'} size="sm">
|
|
||||||
{t('modelManager.allModels')}
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={setModelFormatFilter.bind(null, 'diffusers')}
|
|
||||||
isChecked={modelFormatFilter === 'diffusers'}
|
|
||||||
>
|
|
||||||
{t('modelManager.diffusersModels')}
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
size="sm"
|
|
||||||
onClick={setModelFormatFilter.bind(null, 'checkpoint')}
|
|
||||||
isChecked={modelFormatFilter === 'checkpoint'}
|
|
||||||
>
|
|
||||||
{t('modelManager.checkpointModels')}
|
|
||||||
</Button>
|
|
||||||
<Button size="sm" onClick={setModelFormatFilter.bind(null, 'lora')} isChecked={modelFormatFilter === 'lora'}>
|
|
||||||
{t('modelManager.loraModels')}
|
|
||||||
</Button>
|
|
||||||
</ButtonGroup>
|
|
||||||
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.search')}</FormLabel>
|
|
||||||
<Input onChange={handleSearchFilter} />
|
|
||||||
</FormControl>
|
|
||||||
|
|
||||||
<Flex flexDirection="column" gap={4} maxHeight={window.innerHeight - 280} overflow="scroll">
|
|
||||||
{/* Diffusers List */}
|
|
||||||
{isLoadingDiffusersModels && <FetchingModelsLoader loadingMessage="Loading Diffusers..." />}
|
|
||||||
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingDiffusersModels &&
|
|
||||||
filteredDiffusersModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="Diffusers"
|
|
||||||
modelList={filteredDiffusersModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="diffusers"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* Checkpoints List */}
|
|
||||||
{isLoadingCheckpointModels && <FetchingModelsLoader loadingMessage="Loading Checkpoints..." />}
|
|
||||||
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
|
||||||
!isLoadingCheckpointModels &&
|
|
||||||
filteredCheckpointModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="Checkpoints"
|
|
||||||
modelList={filteredCheckpointModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="checkpoints"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* LoRAs List */}
|
|
||||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
|
||||||
{['all', 'lora'].includes(modelFormatFilter) && !isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
|
||||||
<ModelListWrapper
|
|
||||||
title="LoRAs"
|
|
||||||
modelList={filteredLoraModels}
|
|
||||||
selected={{ selectedModelId, setSelectedModelId }}
|
|
||||||
key="loras"
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ModelList);
|
|
||||||
|
|
||||||
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
|
|
||||||
data: EntityState<T, string> | undefined,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_format: ModelFormat | undefined,
|
|
||||||
nameFilter: string
|
|
||||||
) => {
|
|
||||||
const filteredModels: T[] = [];
|
|
||||||
forEach(data?.entities, (model) => {
|
|
||||||
if (!model) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
|
||||||
|
|
||||||
const matchesFormat = model_format === undefined || model.format === model_format;
|
|
||||||
const matchesType = model.type === model_type;
|
|
||||||
|
|
||||||
if (matchesFilter && matchesFormat && matchesType) {
|
|
||||||
filteredModels.push(model);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return filteredModels;
|
|
||||||
};
|
|
||||||
|
|
||||||
const StyledModelContainer = memo((props: PropsWithChildren) => {
|
|
||||||
return (
|
|
||||||
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
|
|
||||||
{props.children}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
StyledModelContainer.displayName = 'StyledModelContainer';
|
|
||||||
|
|
||||||
type ModelListWrapperProps = {
|
|
||||||
title: string;
|
|
||||||
modelList: MainModelConfig[] | LoRAConfig[];
|
|
||||||
selected: ModelListProps;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
|
||||||
const { title, modelList, selected } = props;
|
|
||||||
return (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex gap={2} flexDir="column">
|
|
||||||
<Text variant="subtext" fontSize="sm">
|
|
||||||
{title}
|
|
||||||
</Text>
|
|
||||||
{modelList.map((model) => (
|
|
||||||
<ModelListItem
|
|
||||||
key={model.key}
|
|
||||||
model={model}
|
|
||||||
isSelected={selected.selectedModelId === model.key}
|
|
||||||
setSelectedModelId={selected.setSelectedModelId}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
ModelListWrapper.displayName = 'ModelListWrapper';
|
|
||||||
|
|
||||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
|
||||||
return (
|
|
||||||
<StyledModelContainer>
|
|
||||||
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
|
|
||||||
<Spinner />
|
|
||||||
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
|
|
||||||
</Flex>
|
|
||||||
</StyledModelContainer>
|
|
||||||
);
|
|
||||||
});
|
|
||||||
|
|
||||||
FetchingModelsLoader.displayName = 'FetchingModelsLoader';
|
|
@ -1,111 +0,0 @@
|
|||||||
import {
|
|
||||||
Badge,
|
|
||||||
Button,
|
|
||||||
ConfirmationAlertDialog,
|
|
||||||
Flex,
|
|
||||||
IconButton,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
useDisclosure,
|
|
||||||
} from '@invoke-ai/ui-library';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { memo, useCallback } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
|
||||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
|
||||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
type ModelListItemProps = {
|
|
||||||
model: MainModelConfig | LoRAConfig;
|
|
||||||
isSelected: boolean;
|
|
||||||
setSelectedModelId: (v: string | undefined) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
const ModelListItem = (props: ModelListItemProps) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const [deleteModel] = useDeleteModelsMutation();
|
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
|
||||||
|
|
||||||
const { model, isSelected, setSelectedModelId } = props;
|
|
||||||
|
|
||||||
const handleSelectModel = useCallback(() => {
|
|
||||||
setSelectedModelId(model.key);
|
|
||||||
}, [model.key, setSelectedModelId]);
|
|
||||||
|
|
||||||
const handleModelDelete = useCallback(() => {
|
|
||||||
deleteModel({ key: model.key })
|
|
||||||
.unwrap()
|
|
||||||
.then((_) => {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
|
||||||
status: 'success',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
if (error) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
setSelectedModelId(undefined);
|
|
||||||
}, [deleteModel, model, setSelectedModelId, dispatch, t]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2} alignItems="center" w="full">
|
|
||||||
<Flex
|
|
||||||
as={Button}
|
|
||||||
isChecked={isSelected}
|
|
||||||
variant={isSelected ? 'solid' : 'ghost'}
|
|
||||||
justifyContent="start"
|
|
||||||
p={2}
|
|
||||||
borderRadius="base"
|
|
||||||
w="full"
|
|
||||||
alignItems="center"
|
|
||||||
onClick={handleSelectModel}
|
|
||||||
>
|
|
||||||
<Flex gap={4} alignItems="center">
|
|
||||||
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
|
||||||
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
|
||||||
</Badge>
|
|
||||||
<Tooltip label={model.description} placement="bottom">
|
|
||||||
<Text>{model.name}</Text>
|
|
||||||
</Tooltip>
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
<IconButton
|
|
||||||
onClick={onOpen}
|
|
||||||
icon={<PiTrashSimpleBold />}
|
|
||||||
aria-label={t('modelManager.deleteConfig')}
|
|
||||||
colorScheme="error"
|
|
||||||
/>
|
|
||||||
<ConfirmationAlertDialog
|
|
||||||
isOpen={isOpen}
|
|
||||||
onClose={onClose}
|
|
||||||
title={t('modelManager.deleteModel')}
|
|
||||||
acceptCallback={handleModelDelete}
|
|
||||||
acceptButtonText={t('modelManager.delete')}
|
|
||||||
>
|
|
||||||
<Flex rowGap={4} flexDirection="column">
|
|
||||||
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
|
||||||
<Text>{t('modelManager.deleteMsg2')}</Text>
|
|
||||||
</Flex>
|
|
||||||
</ConfirmationAlertDialog>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ModelListItem);
|
|
@ -1,14 +0,0 @@
|
|||||||
import { Flex } from '@invoke-ai/ui-library';
|
|
||||||
import { memo } from 'react';
|
|
||||||
|
|
||||||
import SyncModels from './ModelManagerSettingsPanel/SyncModels';
|
|
||||||
|
|
||||||
const ModelManagerSettingsPanel = () => {
|
|
||||||
return (
|
|
||||||
<Flex>
|
|
||||||
<SyncModels />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(ModelManagerSettingsPanel);
|
|
@ -1,22 +0,0 @@
|
|||||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
|
||||||
import { SyncModelsButton } from 'features/modelManager/components/SyncModels/SyncModelsButton';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
|
|
||||||
const SyncModels = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex w="full" p={4} borderRadius={4} gap={4} justifyContent="space-between" alignItems="center" bg="base.800">
|
|
||||||
<Flex flexDirection="column" gap={2}>
|
|
||||||
<Text fontWeight="semibold">{t('modelManager.syncModels')}</Text>
|
|
||||||
<Text fontSize="sm" variant="subtext">
|
|
||||||
{t('modelManager.syncModelsDesc')}
|
|
||||||
</Text>
|
|
||||||
</Flex>
|
|
||||||
<SyncModelsButton />
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(SyncModels);
|
|
@ -1,36 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
|
||||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
|
||||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
|
||||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
|
||||||
];
|
|
||||||
|
|
||||||
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
|
||||||
<Combobox value={value} options={options} onChange={onChange} />
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(BaseModelSelect);
|
|
@ -1,32 +0,0 @@
|
|||||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
|
||||||
import { useController, type UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
|
|
||||||
import type { CheckpointModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const sx: ChakraProps['sx'] = { w: 'full' };
|
|
||||||
|
|
||||||
const CheckpointConfigsSelect = (props: UseControllerProps<CheckpointModelConfig>) => {
|
|
||||||
const { data } = useGetCheckpointConfigsQuery();
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const options = useMemo<ComboboxOption[]>(() => (data ? data.map((i) => ({ label: i, value: i })) : []), [data]);
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value, options]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.configFile')}</FormLabel>
|
|
||||||
<Combobox placeholder="Select A Config File" value={value} options={options} onChange={onChange} sx={sx} />
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(CheckpointConfigsSelect);
|
|
@ -1,34 +0,0 @@
|
|||||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|
||||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
|
||||||
import { useCallback, useMemo } from 'react';
|
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
|
||||||
import { useController } from 'react-hook-form';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = [
|
|
||||||
{ value: 'normal', label: 'Normal' },
|
|
||||||
{ value: 'inpaint', label: 'Inpaint' },
|
|
||||||
{ value: 'depth', label: 'Depth' },
|
|
||||||
];
|
|
||||||
|
|
||||||
const ModelVariantSelect = <T extends CheckpointModelConfig | DiffusersModelConfig>(props: UseControllerProps<T>) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { field } = useController(props);
|
|
||||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
|
||||||
const onChange = useCallback<ComboboxOnChange>(
|
|
||||||
(v) => {
|
|
||||||
field.onChange(v?.value);
|
|
||||||
},
|
|
||||||
[field]
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
<FormControl>
|
|
||||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
|
||||||
<Combobox value={value} options={options} onChange={onChange} />
|
|
||||||
</FormControl>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default typedMemo(ModelVariantSelect);
|
|
@ -1,7 +1,7 @@
|
|||||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type {
|
import type {
|
||||||
SDXLRefinerModelFieldInputInstance,
|
SDXLRefinerModelFieldInputInstance,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { memo, useCallback } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||||
import { pick } from 'lodash-es';
|
import { pick } from 'lodash-es';
|
||||||
|
@ -16,7 +16,7 @@ import { EMPTY_ARRAY } from 'app/store/util';
|
|||||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||||
|
@ -10,7 +10,6 @@ import type {
|
|||||||
IPAdapterModelConfig,
|
IPAdapterModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
MainModelConfig,
|
MainModelConfig,
|
||||||
MergeModelConfig,
|
|
||||||
T2IAdapterModelConfig,
|
T2IAdapterModelConfig,
|
||||||
TextualInversionModelConfig,
|
TextualInversionModelConfig,
|
||||||
VAEModelConfig,
|
VAEModelConfig,
|
||||||
@ -38,22 +37,9 @@ type DeleteMainModelArg = {
|
|||||||
|
|
||||||
type DeleteMainModelResponse = void;
|
type DeleteMainModelResponse = void;
|
||||||
|
|
||||||
type ConvertMainModelArg = {
|
|
||||||
base_model: BaseModelType;
|
|
||||||
model_name: string;
|
|
||||||
convert_dest_directory?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ConvertMainModelResponse =
|
type ConvertMainModelResponse =
|
||||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type MergeMainModelArg = {
|
|
||||||
base_model: BaseModelType;
|
|
||||||
body: MergeModelConfig;
|
|
||||||
};
|
|
||||||
|
|
||||||
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
type ImportMainModelArg = {
|
type ImportMainModelArg = {
|
||||||
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
||||||
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
||||||
@ -72,20 +58,11 @@ type DeleteImportModelsResponse =
|
|||||||
type PruneModelImportsResponse =
|
type PruneModelImportsResponse =
|
||||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type ImportAdvancedModelArg = {
|
|
||||||
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
|
||||||
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
|
||||||
};
|
|
||||||
|
|
||||||
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
|
||||||
|
|
||||||
export type ScanFolderResponse =
|
export type ScanFolderResponse =
|
||||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||||
|
|
||||||
type CheckpointConfigsResponse =
|
|
||||||
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
|
||||||
|
|
||||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||||
selectId: (entity) => entity.key,
|
selectId: (entity) => entity.key,
|
||||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||||
@ -199,16 +176,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
invalidatesTags: ['Model', 'ModelImports'],
|
||||||
}),
|
}),
|
||||||
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
|
||||||
query: ({ source, config}) => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl('install'),
|
|
||||||
method: 'POST',
|
|
||||||
body: { source, config },
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['Model', 'ModelImports'],
|
|
||||||
}),
|
|
||||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||||
query: ({ key }) => {
|
query: ({ key }) => {
|
||||||
return {
|
return {
|
||||||
@ -218,25 +185,14 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
convertMainModels: build.mutation<ConvertMainModelResponse, ConvertMainModelArg>({
|
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
|
||||||
query: ({ base_model, model_name, convert_dest_directory }) => {
|
query: (key) => {
|
||||||
return {
|
return {
|
||||||
url: buildModelsUrl(`convert/${base_model}/main/${model_name}`),
|
url: buildModelsUrl(`convert/${key}`),
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
params: { convert_dest_directory },
|
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['ModelConfig'],
|
||||||
}),
|
|
||||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
|
||||||
query: ({ base_model, body }) => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`merge/${base_model}`),
|
|
||||||
method: 'PUT',
|
|
||||||
body: body,
|
|
||||||
};
|
|
||||||
},
|
|
||||||
invalidatesTags: ['Model'],
|
|
||||||
}),
|
}),
|
||||||
getModelConfig: build.query<GetModelConfigResponse, string>({
|
getModelConfig: build.query<GetModelConfigResponse, string>({
|
||||||
query: (key) => buildModelsUrl(`i/${key}`),
|
query: (key) => buildModelsUrl(`i/${key}`),
|
||||||
@ -323,13 +279,6 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['ModelImports'],
|
invalidatesTags: ['ModelImports'],
|
||||||
}),
|
}),
|
||||||
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
|
||||||
query: () => {
|
|
||||||
return {
|
|
||||||
url: buildModelsUrl(`ckpt_confs`),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -345,13 +294,10 @@ export const {
|
|||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useUpdateModelsMutation,
|
useUpdateModelsMutation,
|
||||||
useImportMainModelsMutation,
|
useImportMainModelsMutation,
|
||||||
useImportAdvancedModelMutation,
|
|
||||||
useConvertMainModelsMutation,
|
useConvertMainModelsMutation,
|
||||||
useMergeMainModelsMutation,
|
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
useScanModelsQuery,
|
useScanModelsQuery,
|
||||||
useLazyScanModelsQuery,
|
useLazyScanModelsQuery,
|
||||||
useGetCheckpointConfigsQuery,
|
|
||||||
useGetModelImportsQuery,
|
useGetModelImportsQuery,
|
||||||
useGetModelMetadataQuery,
|
useGetModelMetadataQuery,
|
||||||
useDeleteModelImportMutation,
|
useDeleteModelImportMutation,
|
||||||
|
Loading…
Reference in New Issue
Block a user