mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up old model manager components and endpoints
This commit is contained in:
parent
7b1b6d3235
commit
cfcb68696c
@ -15,8 +15,7 @@ import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynam
|
||||
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
|
||||
import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice';
|
||||
import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
|
||||
import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
|
||||
@ -55,7 +54,6 @@ const allReducers = {
|
||||
[deleteImageModalSlice.name]: deleteImageModalSlice.reducer,
|
||||
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
|
||||
[loraSlice.name]: loraSlice.reducer,
|
||||
[modelManagerSlice.name]: modelManagerSlice.reducer,
|
||||
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
|
||||
[sdxlSlice.name]: sdxlSlice.reducer,
|
||||
[queueSlice.name]: queueSlice.reducer,
|
||||
@ -103,7 +101,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
|
||||
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
|
||||
[sdxlPersistConfig.name]: sdxlPersistConfig,
|
||||
[loraPersistConfig.name]: loraPersistConfig,
|
||||
[modelManagerPersistConfig.name]: modelManagerPersistConfig,
|
||||
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
|
||||
[hrfPersistConfig.name]: hrfPersistConfig,
|
||||
};
|
||||
|
||||
|
@ -1,32 +0,0 @@
|
||||
import type { ButtonProps } from '@invoke-ai/ui-library';
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
import { useSyncModels } from './useSyncModels';
|
||||
|
||||
export const SyncModelsButton = memo((props: Omit<ButtonProps, 'children'>) => {
|
||||
const { t } = useTranslation();
|
||||
const { syncModels, isLoading } = useSyncModels();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
if (!isSyncModelEnabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Button
|
||||
isLoading={isLoading}
|
||||
onClick={syncModels}
|
||||
leftIcon={<PiArrowsClockwiseBold />}
|
||||
minW="max-content"
|
||||
{...props}
|
||||
>
|
||||
{t('modelManager.syncModels')}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
SyncModelsButton.displayName = 'SyncModelsButton';
|
@ -1,33 +0,0 @@
|
||||
import type { IconButtonProps } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
import { useSyncModels } from './useSyncModels';
|
||||
|
||||
export const SyncModelsIconButton = memo((props: Omit<IconButtonProps, 'aria-label'>) => {
|
||||
const { t } = useTranslation();
|
||||
const { syncModels, isLoading } = useSyncModels();
|
||||
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;
|
||||
|
||||
if (!isSyncModelEnabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
icon={<PiArrowsClockwiseBold />}
|
||||
tooltip={t('modelManager.syncModels')}
|
||||
aria-label={t('modelManager.syncModels')}
|
||||
isLoading={isLoading}
|
||||
onClick={syncModels}
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
SyncModelsIconButton.displayName = 'SyncModelsIconButton';
|
@ -1,40 +0,0 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useSyncModelsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
export const useSyncModels = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [_syncModels, { isLoading }] = useSyncModelsMutation();
|
||||
const syncModels = useCallback(() => {
|
||||
_syncModels()
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelsSynced')}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelSyncFailed')}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [dispatch, _syncModels, t]);
|
||||
|
||||
return { syncModels, isLoading };
|
||||
};
|
@ -1,47 +0,0 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
searchFolder: string | null;
|
||||
advancedAddScanModel: string | null;
|
||||
};
|
||||
|
||||
export const initialModelManagerState: ModelManagerState = {
|
||||
_version: 1,
|
||||
searchFolder: null,
|
||||
advancedAddScanModel: null,
|
||||
};
|
||||
|
||||
export const modelManagerSlice = createSlice({
|
||||
name: 'modelmanager',
|
||||
initialState: initialModelManagerState,
|
||||
reducers: {
|
||||
setSearchFolder: (state, action: PayloadAction<string | null>) => {
|
||||
state.searchFolder = action.payload;
|
||||
},
|
||||
setAdvancedAddScanModel: (state, action: PayloadAction<string | null>) => {
|
||||
state.advancedAddScanModel = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { setSearchFolder, setAdvancedAddScanModel } = modelManagerSlice.actions;
|
||||
|
||||
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
export const migrateModelManagerState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
export const modelManagerPersistConfig: PersistConfig<ModelManagerState> = {
|
||||
name: modelManagerSlice.name,
|
||||
initialState: initialModelManagerState,
|
||||
migrate: migrateModelManagerState,
|
||||
persistDenylist: [],
|
||||
};
|
@ -1,35 +0,0 @@
|
||||
import { Button, ButtonGroup, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetModelImportsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import AdvancedAddModels from './AdvancedAddModels';
|
||||
import SimpleAddModels from './SimpleAddModels';
|
||||
|
||||
const AddModels = () => {
|
||||
const { t } = useTranslation();
|
||||
const [addModelMode, setAddModelMode] = useState<'simple' | 'advanced'>('simple');
|
||||
const handleAddModelSimple = useCallback(() => setAddModelMode('simple'), []);
|
||||
const handleAddModelAdvanced = useCallback(() => setAddModelMode('advanced'), []);
|
||||
const { data } = useGetModelImportsQuery();
|
||||
console.log({ data });
|
||||
return (
|
||||
<Flex flexDirection="column" width="100%" overflow="scroll" maxHeight={window.innerHeight - 250} gap={4}>
|
||||
<ButtonGroup>
|
||||
<Button size="sm" isChecked={addModelMode === 'simple'} onClick={handleAddModelSimple}>
|
||||
{t('common.simple')}
|
||||
</Button>
|
||||
<Button size="sm" isChecked={addModelMode === 'advanced'} onClick={handleAddModelAdvanced}>
|
||||
{t('common.advanced')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
<Flex p={4} borderRadius={4} bg="base.800">
|
||||
{addModelMode === 'simple' && <SimpleAddModels />}
|
||||
{addModelMode === 'advanced' && <AdvancedAddModels />}
|
||||
</Flex>
|
||||
<Flex>{data?.map((model) => <Text key={model.id}>{model.status}</Text>)}</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AddModels);
|
@ -1,168 +0,0 @@
|
||||
import { Button, Checkbox, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
||||
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
|
||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { CSSProperties, FocusEventHandler } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
import { getModelName } from './util';
|
||||
|
||||
type AdvancedAddCheckpointProps = {
|
||||
model_path?: string;
|
||||
};
|
||||
|
||||
const AdvancedAddCheckpoint = (props: AdvancedAddCheckpointProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { model_path } = props;
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
getValues,
|
||||
setValue,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<CheckpointModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model_path ? getModelName(model_path) : '',
|
||||
base_model: 'sd-1',
|
||||
model_type: 'main',
|
||||
path: model_path ? model_path : '',
|
||||
description: '',
|
||||
model_format: 'checkpoint',
|
||||
error: undefined,
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
config: 'configs\\stable-diffusion\\v1-inference.yaml',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const [addMainModel] = useAddMainModelsMutation();
|
||||
|
||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
|
||||
(values) => {
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.model_name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
|
||||
// Close Advanced Panel in Scan Models tab
|
||||
if (model_path) {
|
||||
dispatch(setAdvancedAddScanModel(null));
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[addMainModel, dispatch, model_path, reset, t]
|
||||
);
|
||||
|
||||
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
|
||||
(e) => {
|
||||
if (getValues().model_name === '') {
|
||||
const modelName = getModelName(e.currentTarget.value);
|
||||
if (modelName) {
|
||||
setValue('model_name', modelName as string);
|
||||
}
|
||||
}
|
||||
},
|
||||
[getValues, setValue]
|
||||
);
|
||||
|
||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
||||
<FormLabel>{t('modelManager.model')}</FormLabel>
|
||||
<Input
|
||||
{...register('model_name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base_model" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
onBlur,
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
||||
<Flex flexDirection="column" width="100%" gap={2}>
|
||||
{!useCustomConfig ? (
|
||||
<CheckpointConfigsSelect control={control} name="config" />
|
||||
) : (
|
||||
<FormControl isRequired>
|
||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
||||
</FormControl>
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
export default memo(AdvancedAddCheckpoint);
|
@ -1,148 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { CSSProperties, FocusEventHandler } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useAddMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
import { getModelName } from './util';
|
||||
|
||||
type AdvancedAddDiffusersProps = {
|
||||
model_path?: string;
|
||||
};
|
||||
|
||||
const AdvancedAddDiffusers = (props: AdvancedAddDiffusersProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { model_path } = props;
|
||||
|
||||
const [addMainModel] = useAddMainModelsMutation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
getValues,
|
||||
setValue,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<DiffusersModelConfig>({
|
||||
defaultValues: {
|
||||
model_name: model_path ? getModelName(model_path, false) : '',
|
||||
base_model: 'sd-1',
|
||||
model_type: 'main',
|
||||
path: model_path ? model_path : '',
|
||||
description: '',
|
||||
model_format: 'diffusers',
|
||||
error: undefined,
|
||||
vae: '',
|
||||
variant: 'normal',
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
||||
(values) => {
|
||||
addMainModel({
|
||||
body: values,
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelAdded', {
|
||||
modelName: values.model_name,
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
reset();
|
||||
// Close Advanced Panel in Scan Models tab
|
||||
if (model_path) {
|
||||
dispatch(setAdvancedAddScanModel(null));
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[addMainModel, dispatch, model_path, reset, t]
|
||||
);
|
||||
|
||||
const onBlur: FocusEventHandler<HTMLInputElement> = useCallback(
|
||||
(e) => {
|
||||
if (getValues().model_name === '') {
|
||||
const modelName = getModelName(e.currentTarget.value, false);
|
||||
if (modelName) {
|
||||
setValue('model_name', modelName as string);
|
||||
}
|
||||
}
|
||||
},
|
||||
[getValues, setValue]
|
||||
);
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit(onSubmit)} style={formStyles}>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<FormControl isInvalid={Boolean(errors.model_name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('model_name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.model_name?.message && <FormErrorMessage>{errors.model_name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base_model" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
onBlur,
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
||||
|
||||
<Button mt={2} type="submit">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
export default memo(AdvancedAddDiffusers);
|
@ -1,50 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
|
||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
||||
|
||||
export const zManualAddMode = z.enum(['diffusers', 'checkpoint']);
|
||||
export type ManualAddMode = z.infer<typeof zManualAddMode>;
|
||||
export const isManualAddMode = (v: unknown): v is ManualAddMode => zManualAddMode.safeParse(v).success;
|
||||
|
||||
const AdvancedAddModels = () => {
|
||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
||||
|
||||
const { t } = useTranslation();
|
||||
const handleChange: ComboboxOnChange = useCallback((v) => {
|
||||
if (!isManualAddMode(v?.value)) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v.value);
|
||||
}, []);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(
|
||||
() => [
|
||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4} width="100%">
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleChange} />
|
||||
</FormControl>
|
||||
|
||||
<Flex p={4} borderRadius={4} bg="base.850">
|
||||
{advancedAddMode === 'diffusers' && <AdvancedAddDiffusers />}
|
||||
{advancedAddMode === 'checkpoint' && <AdvancedAddCheckpoint />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(AdvancedAddModels);
|
@ -1,176 +0,0 @@
|
||||
import { Button, Flex, FormControl, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { difference, forEach, intersection, map, values } from 'lodash-es';
|
||||
import type { ChangeEvent, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import type { SearchFolderResponse } from 'services/api/endpoints/models';
|
||||
import {
|
||||
useGetMainModelsQuery,
|
||||
useGetModelsInFolderQuery,
|
||||
useImportMainModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
const FoundModelsList = () => {
|
||||
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
|
||||
// Get paths of models that are already installed
|
||||
const { data: installedModels } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
// Get all model paths from a given directory
|
||||
const { foundModels, alreadyInstalled, filteredModels } = useGetModelsInFolderQuery(
|
||||
{
|
||||
search_path: searchFolder ? searchFolder : '',
|
||||
},
|
||||
{
|
||||
selectFromResult: ({ data }) => {
|
||||
const installedModelValues = values(installedModels?.entities);
|
||||
const installedModelPaths = map(installedModelValues, 'path');
|
||||
// Only select models those that aren't already installed to Invoke
|
||||
const notInstalledModels = difference(data, installedModelPaths);
|
||||
const alreadyInstalled = intersection(data, installedModelPaths);
|
||||
return {
|
||||
foundModels: data,
|
||||
alreadyInstalled: foundModelsFilter(alreadyInstalled, nameFilter),
|
||||
filteredModels: foundModelsFilter(notInstalledModels, nameFilter),
|
||||
};
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const quickAddHandler = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
const model_name = e.currentTarget.id.split('\\').splice(-1)[0];
|
||||
importMainModel({
|
||||
body: {
|
||||
location: e.currentTarget.id,
|
||||
},
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `Added Model: ${model_name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
},
|
||||
[dispatch, importMainModel, t]
|
||||
);
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
|
||||
const handleClickSetAdvanced = useCallback((model: string) => dispatch(setAdvancedAddScanModel(model)), [dispatch]);
|
||||
|
||||
const renderModels = ({ models, showActions = true }: { models: string[]; showActions?: boolean }) => {
|
||||
return models.map((model) => {
|
||||
return (
|
||||
<Flex key={model} p={4} gap={4} alignItems="center" borderRadius={4} bg="base.800">
|
||||
<Flex w="full" minW="25%" flexDir="column">
|
||||
<Text fontWeight="semibold">{model.split('\\').slice(-1)[0]}</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{model}
|
||||
</Text>
|
||||
</Flex>
|
||||
{showActions ? (
|
||||
<Flex gap={2}>
|
||||
<Button id={model} onClick={quickAddHandler} isLoading={isLoading}>
|
||||
{t('modelManager.quickAdd')}
|
||||
</Button>
|
||||
<Button onClick={handleClickSetAdvanced.bind(null, model)} isLoading={isLoading}>
|
||||
{t('modelManager.advanced')}
|
||||
</Button>
|
||||
</Flex>
|
||||
) : (
|
||||
<Text fontWeight="semibold" p={2} borderRadius={4} color="invokeBlue.100" bg="invokeBlue.600">
|
||||
{t('common.installed')}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
const renderFoundModels = () => {
|
||||
if (!searchFolder) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!foundModels || foundModels.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" justifyContent="center" alignItems="center" height={96} userSelect="none" bg="base.900">
|
||||
<Text variant="subtext">{t('modelManager.noModels')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={2} w="100%" minW="50%">
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.search')}</FormLabel>
|
||||
<Input onChange={handleSearchFilter} />
|
||||
</FormControl>
|
||||
<Flex p={2} gap={2}>
|
||||
<Text fontWeight="semibold">
|
||||
{t('modelManager.modelsFound')}: {foundModels.length}
|
||||
</Text>
|
||||
<Text fontWeight="semibold" color="invokeBlue.200">
|
||||
{t('common.notInstalled')}: {filteredModels.length}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<ScrollableContent>
|
||||
<Flex gap={2} flexDirection="column">
|
||||
{renderModels({ models: filteredModels })}
|
||||
{renderModels({ models: alreadyInstalled, showActions: false })}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
return renderFoundModels();
|
||||
};
|
||||
|
||||
const foundModelsFilter = (data: SearchFolderResponse | undefined, nameFilter: string) => {
|
||||
const filteredModels: SearchFolderResponse = [];
|
||||
forEach(data, (model) => {
|
||||
if (!model) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (model.includes(nameFilter)) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
||||
export default memo(FoundModelsList);
|
@ -1,95 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Box, Combobox, Flex, FormControl, FormLabel, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setAdvancedAddScanModel } from 'features/modelManager/store/modelManagerSlice';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
|
||||
import AdvancedAddCheckpoint from './AdvancedAddCheckpoint';
|
||||
import AdvancedAddDiffusers from './AdvancedAddDiffusers';
|
||||
import type { ManualAddMode } from './AdvancedAddModels';
|
||||
import { isManualAddMode } from './AdvancedAddModels';
|
||||
|
||||
const ScanAdvancedAddModels = () => {
|
||||
const advancedAddScanModel = useAppSelector((s) => s.modelmanager.advancedAddScanModel);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const options: ComboboxOption[] = useMemo(
|
||||
() => [
|
||||
{ label: t('modelManager.diffusersModels'), value: 'diffusers' },
|
||||
{ label: t('modelManager.checkpointOrSafetensors'), value: 'checkpoint' },
|
||||
],
|
||||
[t]
|
||||
);
|
||||
|
||||
const [advancedAddMode, setAdvancedAddMode] = useState<ManualAddMode>('diffusers');
|
||||
|
||||
const [isCheckpoint, setIsCheckpoint] = useState<boolean>(true);
|
||||
|
||||
useEffect(() => {
|
||||
advancedAddScanModel && ['.ckpt', '.safetensors', '.pth', '.pt'].some((ext) => advancedAddScanModel.endsWith(ext))
|
||||
? setAdvancedAddMode('checkpoint')
|
||||
: setAdvancedAddMode('diffusers');
|
||||
}, [advancedAddScanModel, setAdvancedAddMode, isCheckpoint]);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleClickSetAdvanced = useCallback(() => dispatch(setAdvancedAddScanModel(null)), [dispatch]);
|
||||
|
||||
const handleChangeAddMode = useCallback<ComboboxOnChange>((v) => {
|
||||
if (!isManualAddMode(v?.value)) {
|
||||
return;
|
||||
}
|
||||
setAdvancedAddMode(v.value);
|
||||
if (v.value === 'checkpoint') {
|
||||
setIsCheckpoint(true);
|
||||
} else {
|
||||
setIsCheckpoint(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === advancedAddMode), [options, advancedAddMode]);
|
||||
|
||||
if (!advancedAddScanModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
minWidth="40%"
|
||||
maxHeight="calc(100vh - 300px)"
|
||||
overflow="scroll"
|
||||
p={4}
|
||||
gap={4}
|
||||
borderRadius={4}
|
||||
bg="base.800"
|
||||
>
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Text size="xl" fontWeight="semibold">
|
||||
{isCheckpoint || advancedAddMode === 'checkpoint' ? 'Add Checkpoint Model' : 'Add Diffusers Model'}
|
||||
</Text>
|
||||
<IconButton
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('modelManager.closeAdvanced')}
|
||||
onClick={handleClickSetAdvanced}
|
||||
size="sm"
|
||||
/>
|
||||
</Flex>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={handleChangeAddMode} />
|
||||
</FormControl>
|
||||
{isCheckpoint ? (
|
||||
<AdvancedAddCheckpoint key={advancedAddScanModel} model_path={advancedAddScanModel} />
|
||||
) : (
|
||||
<AdvancedAddDiffusers key={advancedAddScanModel} model_path={advancedAddScanModel} />
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ScanAdvancedAddModels);
|
@ -1,22 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
import FoundModelsList from './FoundModelsList';
|
||||
import ScanAdvancedAddModels from './ScanAdvancedAddModels';
|
||||
import SearchFolderForm from './SearchFolderForm';
|
||||
|
||||
const ScanModels = () => {
|
||||
return (
|
||||
<Flex flexDirection="column" w="100%" h="full" gap={4}>
|
||||
<SearchFolderForm />
|
||||
<Flex gap={4}>
|
||||
<Flex overflow="scroll" gap={4} w="100%" h="full">
|
||||
<FoundModelsList />
|
||||
</Flex>
|
||||
<ScanAdvancedAddModels />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ScanModels);
|
@ -1,103 +0,0 @@
|
||||
import { Flex, IconButton, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setAdvancedAddScanModel, setSearchFolder } from 'features/modelManager/store/modelManagerSlice';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowsCounterClockwiseBold, PiMagnifyingGlassBold, PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useGetModelsInFolderQuery } from 'services/api/endpoints/models';
|
||||
|
||||
type SearchFolderForm = {
|
||||
folder: string;
|
||||
};
|
||||
|
||||
function SearchFolderForm() {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const searchFolder = useAppSelector((s) => s.modelmanager.searchFolder);
|
||||
|
||||
const { refetch: refetchFoundModels } = useGetModelsInFolderQuery({
|
||||
search_path: searchFolder ? searchFolder : '',
|
||||
});
|
||||
|
||||
const searchFolderForm = useForm<SearchFolderForm>({
|
||||
initialValues: {
|
||||
folder: '',
|
||||
},
|
||||
});
|
||||
|
||||
const searchFolderFormSubmitHandler = useCallback(
|
||||
(values: SearchFolderForm) => {
|
||||
dispatch(setSearchFolder(values.folder));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const scanAgainHandler = useCallback(() => {
|
||||
refetchFoundModels();
|
||||
}, [refetchFoundModels]);
|
||||
|
||||
const handleClickClearCheckpointFolder = useCallback(() => {
|
||||
dispatch(setSearchFolder(null));
|
||||
dispatch(setAdvancedAddScanModel(null));
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<form onSubmit={searchFolderForm.onSubmit((values) => searchFolderFormSubmitHandler(values))} style={formStyles}>
|
||||
<Flex w="100%" gap={2} borderRadius={4} alignItems="center">
|
||||
<Flex w="100%" alignItems="center" gap={4} minH={12}>
|
||||
<Text fontSize="sm" fontWeight="semibold" color="base.300" minW="max-content">
|
||||
{t('common.folder')}
|
||||
</Text>
|
||||
{!searchFolder ? (
|
||||
<Input w="100%" size="md" {...searchFolderForm.getInputProps('folder')} />
|
||||
) : (
|
||||
<Flex w="100%" p={2} px={4} bg="base.700" borderRadius={4} fontSize="sm" fontWeight="bold">
|
||||
{searchFolder}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<Flex gap={2}>
|
||||
{!searchFolder ? (
|
||||
<IconButton
|
||||
aria-label={t('modelManager.findModels')}
|
||||
tooltip={t('modelManager.findModels')}
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
fontSize={18}
|
||||
size="sm"
|
||||
type="submit"
|
||||
/>
|
||||
) : (
|
||||
<IconButton
|
||||
aria-label={t('modelManager.scanAgain')}
|
||||
tooltip={t('modelManager.scanAgain')}
|
||||
icon={<PiArrowsCounterClockwiseBold />}
|
||||
onClick={scanAgainHandler}
|
||||
fontSize={18}
|
||||
size="sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
<IconButton
|
||||
aria-label={t('modelManager.clearCheckpointFolder')}
|
||||
tooltip={t('modelManager.clearCheckpointFolder')}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
size="sm"
|
||||
onClick={handleClickClearCheckpointFolder}
|
||||
isDisabled={!searchFolder}
|
||||
colorScheme="red"
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(SearchFolderForm);
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
@ -1,92 +0,0 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Button, Combobox, Flex, FormControl, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import { useForm } from '@mantine/form';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useImportMainModelsMutation } from 'services/api/endpoints/models';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ label: 'None', value: 'none' },
|
||||
{ label: 'v_prediction', value: 'v_prediction' },
|
||||
{ label: 'epsilon', value: 'epsilon' },
|
||||
{ label: 'sample', value: 'sample' },
|
||||
];
|
||||
|
||||
type ExtendedImportModelConfig = {
|
||||
location: string;
|
||||
prediction_type?: 'v_prediction' | 'epsilon' | 'sample' | 'none' | undefined;
|
||||
};
|
||||
|
||||
const SimpleAddModels = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [importMainModel, { isLoading }] = useImportMainModelsMutation();
|
||||
|
||||
const addModelForm = useForm<ExtendedImportModelConfig>({
|
||||
initialValues: {
|
||||
location: '',
|
||||
prediction_type: undefined,
|
||||
},
|
||||
});
|
||||
|
||||
const handleAddModelSubmit = (values: ExtendedImportModelConfig) => {
|
||||
const importModelResponseBody = {
|
||||
config: values.prediction_type === 'none' ? undefined : values.prediction_type,
|
||||
};
|
||||
|
||||
importMainModel({ source: values.location, config: importModelResponseBody })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddedSimple'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
addModelForm.reset();
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={addModelForm.onSubmit((v) => handleAddModelSubmit(v))} style={formStyles}>
|
||||
<Flex flexDirection="column" width="100%" gap={4}>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input placeholder={t('modelManager.simpleModelDesc')} w="100%" {...addModelForm.getInputProps('location')} />
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<Combobox options={options} defaultValue={options[0]} {...addModelForm.getInputProps('prediction_type')} />
|
||||
</FormControl>
|
||||
<Button type="submit" isLoading={isLoading}>
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
);
|
||||
};
|
||||
|
||||
const formStyles: CSSProperties = {
|
||||
width: '100%',
|
||||
};
|
||||
|
||||
export default memo(SimpleAddModels);
|
@ -1,15 +0,0 @@
|
||||
export function getModelName(filepath: string, isCheckpoint: boolean = true) {
|
||||
let regex;
|
||||
if (isCheckpoint) {
|
||||
regex = new RegExp('[^\\\\/]+(?=\\.)');
|
||||
} else {
|
||||
regex = new RegExp('[^\\\\/]+(?=[\\\\/]?$)');
|
||||
}
|
||||
|
||||
const match = filepath.match(regex);
|
||||
if (match) {
|
||||
return match[0];
|
||||
} else {
|
||||
return '';
|
||||
}
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import AddModels from './AddModelsPanel/AddModels';
|
||||
import ScanModels from './AddModelsPanel/ScanModels';
|
||||
|
||||
type AddModelTabs = 'add' | 'scan';
|
||||
|
||||
const ImportModelsPanel = () => {
|
||||
const [addModelTab, setAddModelTab] = useState<AddModelTabs>('add');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleClickAddTab = useCallback(() => setAddModelTab('add'), []);
|
||||
const handleClickScanTab = useCallback(() => setAddModelTab('scan'), []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4} h="full">
|
||||
<ButtonGroup>
|
||||
<Button onClick={handleClickAddTab} isChecked={addModelTab === 'add'} size="sm" width="100%">
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
<Button onClick={handleClickScanTab} isChecked={addModelTab === 'scan'} size="sm" width="100%">
|
||||
{t('modelManager.scanForModels')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
|
||||
{addModelTab === 'add' && <AddModels />}
|
||||
{addModelTab === 'scan' && <ScanModels />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ImportModelsPanel);
|
@ -1,352 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
Checkbox,
|
||||
Combobox,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormHelperText,
|
||||
FormLabel,
|
||||
Input,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { pickBy } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetMainModelsQuery, useMergeMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { BaseModelType, MergeModelConfig } from 'services/api/types';
|
||||
|
||||
const baseModelTypeSelectOptions: ComboboxOption[] = [
|
||||
{ label: 'Stable Diffusion 1', value: 'sd-1' },
|
||||
{ label: 'Stable Diffusion 2', value: 'sd-2' },
|
||||
];
|
||||
|
||||
type MergeInterpolationMethods = 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference';
|
||||
|
||||
const MergeModelsPanel = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data } = useGetMainModelsQuery(ALL_BASE_MODELS);
|
||||
|
||||
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
|
||||
|
||||
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
|
||||
const valueBaseModel = useMemo(() => baseModelTypeSelectOptions.find((o) => o.value === baseModel), [baseModel]);
|
||||
const sd1DiffusersModels = pickBy(
|
||||
data?.entities,
|
||||
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
|
||||
);
|
||||
|
||||
const sd2DiffusersModels = pickBy(
|
||||
data?.entities,
|
||||
(value, _) => value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
|
||||
);
|
||||
|
||||
const modelsMap = useMemo(() => {
|
||||
return {
|
||||
'sd-1': sd1DiffusersModels,
|
||||
'sd-2': sd2DiffusersModels,
|
||||
};
|
||||
}, [sd1DiffusersModels, sd2DiffusersModels]);
|
||||
|
||||
const [modelOne, setModelOne] = useState<string | null>(
|
||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[0] ?? null
|
||||
);
|
||||
const [modelTwo, setModelTwo] = useState<string | null>(
|
||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])?.[1] ?? null
|
||||
);
|
||||
const [modelThree, setModelThree] = useState<string | null>(null);
|
||||
|
||||
const [mergedModelName, setMergedModelName] = useState<string>('');
|
||||
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
|
||||
|
||||
const [modelMergeInterp, setModelMergeInterp] = useState<MergeInterpolationMethods>('weighted_sum');
|
||||
|
||||
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<'root' | 'custom'>('root');
|
||||
|
||||
const [modelMergeCustomSaveLoc, setModelMergeCustomSaveLoc] = useState<string>('');
|
||||
|
||||
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
|
||||
|
||||
const optionsModelOne = useMemo(
|
||||
() =>
|
||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
||||
.filter((model) => model !== modelTwo && model !== modelThree)
|
||||
.map((model) => ({ label: model, value: model })),
|
||||
[modelsMap, baseModel, modelTwo, modelThree]
|
||||
);
|
||||
|
||||
const optionsModelTwo = useMemo(
|
||||
() =>
|
||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
||||
.filter((model) => model !== modelOne && model !== modelThree)
|
||||
.map((model) => ({ label: model, value: model })),
|
||||
[modelsMap, baseModel, modelOne, modelThree]
|
||||
);
|
||||
|
||||
const optionsModelThree = useMemo(
|
||||
() =>
|
||||
Object.keys(modelsMap[baseModel as keyof typeof modelsMap])
|
||||
.filter((model) => model !== modelOne && model !== modelTwo)
|
||||
.map((model) => ({ label: model, value: model })),
|
||||
[modelsMap, baseModel, modelOne, modelTwo]
|
||||
);
|
||||
|
||||
const onChangeBaseModel = useCallback<ComboboxOnChange>((v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
if (!(v.value === 'sd-1' || v.value === 'sd-2')) {
|
||||
return;
|
||||
}
|
||||
setBaseModel(v.value);
|
||||
setModelOne(null);
|
||||
setModelTwo(null);
|
||||
}, []);
|
||||
|
||||
const onChangeModelOne = useCallback<ComboboxOnChange>((v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setModelOne(v.value);
|
||||
}, []);
|
||||
const onChangeModelTwo = useCallback<ComboboxOnChange>((v) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
setModelTwo(v.value);
|
||||
}, []);
|
||||
const onChangeModelThree = useCallback<ComboboxOnChange>((v) => {
|
||||
if (!v) {
|
||||
setModelThree(null);
|
||||
setModelMergeInterp('add_difference');
|
||||
} else {
|
||||
setModelThree(v.value);
|
||||
setModelMergeInterp('weighted_sum');
|
||||
}
|
||||
}, []);
|
||||
|
||||
const valueModelOne = useMemo(() => optionsModelOne.find((o) => o.value === modelOne), [modelOne, optionsModelOne]);
|
||||
const valueModelTwo = useMemo(() => optionsModelTwo.find((o) => o.value === modelTwo), [modelTwo, optionsModelTwo]);
|
||||
const valueModelThree = useMemo(
|
||||
() => optionsModelThree.find((o) => o.value === modelThree),
|
||||
[modelThree, optionsModelThree]
|
||||
);
|
||||
|
||||
const handleChangeMergedModelName = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => setMergedModelName(e.target.value),
|
||||
[]
|
||||
);
|
||||
const handleChangeModelMergeAlpha = useCallback((v: number) => setModelMergeAlpha(v), []);
|
||||
const handleResetModelMergeAlpha = useCallback(() => setModelMergeAlpha(0.5), []);
|
||||
const handleChangeMergeInterp = useCallback((v: MergeInterpolationMethods) => setModelMergeInterp(v), []);
|
||||
const handleChangeMergeSaveLocType = useCallback((v: 'root' | 'custom') => setModelMergeSaveLocType(v), []);
|
||||
const handleChangeMergeCustomSaveLoc = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => setModelMergeCustomSaveLoc(e.target.value),
|
||||
[]
|
||||
);
|
||||
const handleChangeModelMergeForce = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => setModelMergeForce(e.target.checked),
|
||||
[]
|
||||
);
|
||||
|
||||
const mergeModelsHandler = useCallback(() => {
|
||||
const models_names: string[] = [];
|
||||
|
||||
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
|
||||
modelsToMerge = modelsToMerge.filter((model) => model !== null);
|
||||
modelsToMerge.forEach((model) => {
|
||||
const n = model?.split('/')?.[2];
|
||||
if (n) {
|
||||
models_names.push(n);
|
||||
}
|
||||
});
|
||||
|
||||
const mergeModelsInfo: MergeModelConfig['body'] = {
|
||||
model_names: models_names,
|
||||
merged_model_name: mergedModelName !== '' ? mergedModelName : models_names.join('-'),
|
||||
alpha: modelMergeAlpha,
|
||||
interp: modelMergeInterp,
|
||||
force: modelMergeForce,
|
||||
merge_dest_directory: modelMergeSaveLocType === 'root' ? undefined : modelMergeCustomSaveLoc,
|
||||
};
|
||||
|
||||
mergeModels({
|
||||
base_model: baseModel,
|
||||
body: { body: mergeModelsInfo },
|
||||
})
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelsMerged'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelsMergeFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [
|
||||
baseModel,
|
||||
dispatch,
|
||||
mergeModels,
|
||||
mergedModelName,
|
||||
modelMergeAlpha,
|
||||
modelMergeCustomSaveLoc,
|
||||
modelMergeForce,
|
||||
modelMergeInterp,
|
||||
modelMergeSaveLocType,
|
||||
modelOne,
|
||||
modelThree,
|
||||
modelTwo,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex flexDir="column" gap={1}>
|
||||
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.modelMergeHeaderHelp2')}
|
||||
</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex columnGap={4}>
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('modelManager.modelType')}</FormLabel>
|
||||
<Combobox options={baseModelTypeSelectOptions} value={valueBaseModel} onChange={onChangeBaseModel} />
|
||||
</FormControl>
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('modelManager.modelOne')}</FormLabel>
|
||||
<Combobox options={optionsModelOne} value={valueModelOne} onChange={onChangeModelOne} />
|
||||
</FormControl>
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('modelManager.modelTwo')}</FormLabel>
|
||||
<Combobox options={optionsModelTwo} value={valueModelTwo} onChange={onChangeModelTwo} />
|
||||
</FormControl>
|
||||
<FormControl w="full">
|
||||
<FormLabel>{t('modelManager.modelThree')}</FormLabel>
|
||||
<Combobox options={optionsModelThree} value={valueModelThree} onChange={onChangeModelThree} isClearable />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.mergedModelName')}</FormLabel>
|
||||
<Input value={mergedModelName} onChange={handleChangeMergedModelName} />
|
||||
</FormControl>
|
||||
|
||||
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.800">
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.alpha')}</FormLabel>
|
||||
<CompositeSlider
|
||||
min={0.01}
|
||||
max={0.99}
|
||||
step={0.01}
|
||||
value={modelMergeAlpha}
|
||||
onChange={handleChangeModelMergeAlpha}
|
||||
onReset={handleResetModelMergeAlpha}
|
||||
marks
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
min={0.01}
|
||||
max={0.99}
|
||||
step={0.01}
|
||||
value={modelMergeAlpha}
|
||||
onChange={handleChangeModelMergeAlpha}
|
||||
onReset={handleResetModelMergeAlpha}
|
||||
/>
|
||||
<FormHelperText>{t('modelManager.modelMergeAlphaHelp')}</FormHelperText>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Flex padding={4} gap={4} borderRadius="base" bg="base.800">
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.interpolationType')}
|
||||
</Text>
|
||||
<RadioGroup value={modelMergeInterp} onChange={handleChangeMergeInterp}>
|
||||
<Flex columnGap={4}>
|
||||
{modelThree === null ? (
|
||||
<>
|
||||
<Radio value="weighted_sum">
|
||||
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
|
||||
</Radio>
|
||||
<Radio value="sigmoid">
|
||||
<Text fontSize="sm">{t('modelManager.sigmoid')}</Text>
|
||||
</Radio>
|
||||
<Radio value="inv_sigmoid">
|
||||
<Text fontSize="sm">{t('modelManager.inverseSigmoid')}</Text>
|
||||
</Radio>
|
||||
</>
|
||||
) : (
|
||||
<Radio value="add_difference">
|
||||
<Tooltip label={t('modelManager.modelMergeInterpAddDifferenceHelp')}>
|
||||
<Text fontSize="sm">{t('modelManager.addDifference')}</Text>
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
)}
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDirection="column" padding={4} borderRadius="base" gap={4} bg="base.900">
|
||||
<Flex columnGap={4}>
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.mergedModelSaveLocation')}
|
||||
</Text>
|
||||
<RadioGroup value={modelMergeSaveLocType} onChange={handleChangeMergeSaveLocType}>
|
||||
<Flex columnGap={4}>
|
||||
<Radio value="root">
|
||||
<Text fontSize="sm">{t('modelManager.invokeAIFolder')}</Text>
|
||||
</Radio>
|
||||
|
||||
<Radio value="custom">
|
||||
<Text fontSize="sm">{t('modelManager.custom')}</Text>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
|
||||
{modelMergeSaveLocType === 'custom' && (
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.mergedModelCustomSaveLocation')}</FormLabel>
|
||||
<Input value={modelMergeCustomSaveLoc} onChange={handleChangeMergeCustomSaveLoc} />
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.ignoreMismatch')}</FormLabel>
|
||||
<Checkbox isChecked={modelMergeForce} onChange={handleChangeModelMergeForce} />
|
||||
</FormControl>
|
||||
|
||||
<Button onClick={mergeModelsHandler} isLoading={isLoading} isDisabled={modelOne === null || modelTwo === null}>
|
||||
{t('modelManager.merge')}
|
||||
</Button>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(MergeModelsPanel);
|
@ -1,63 +0,0 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig, LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
|
||||
const ModelManagerPanel = () => {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
});
|
||||
const { loraModel } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
});
|
||||
|
||||
const model = mainModel ? mainModel : loraModel;
|
||||
|
||||
return (
|
||||
<Flex gap={8} w="full" h="full">
|
||||
<ModelList selectedModelId={selectedModelId} setSelectedModelId={setSelectedModelId} />
|
||||
<ModelEdit model={model} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type ModelEditProps = {
|
||||
model: MainModelConfig | LoRAConfig | undefined;
|
||||
};
|
||||
|
||||
const ModelEdit = (props: ModelEditProps) => {
|
||||
const { t } = useTranslation();
|
||||
const { model } = props;
|
||||
|
||||
if (model?.format === 'checkpoint') {
|
||||
return <CheckpointModelEdit key={model.key} model={model} />;
|
||||
}
|
||||
|
||||
if (model?.format === 'diffusers') {
|
||||
return <DiffusersModelEdit key={model.key} model={model as DiffusersModelConfig} />;
|
||||
}
|
||||
|
||||
if (model?.type === 'lora') {
|
||||
return <LoRAModelEdit key={model.key} model={model} />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" justifyContent="center" alignItems="center" maxH={96} userSelect="none">
|
||||
<Text variant="subtext">{t('modelManager.noModelSelected')}</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelManagerPanel);
|
@ -1,185 +0,0 @@
|
||||
import {
|
||||
Badge,
|
||||
Button,
|
||||
Checkbox,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormErrorMessage,
|
||||
FormLabel,
|
||||
Input,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
||||
import CheckpointConfigsSelect from 'features/modelManager/subpanels/shared/CheckpointConfigsSelect';
|
||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetCheckpointConfigsQuery, useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelConvert from './ModelConvert';
|
||||
|
||||
type CheckpointModelEditProps = {
|
||||
model: CheckpointModelConfig;
|
||||
};
|
||||
|
||||
const CheckpointModelEdit = (props: CheckpointModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
const { data: availableCheckpointConfigs } = useGetCheckpointConfigsQuery();
|
||||
|
||||
const [useCustomConfig, setUseCustomConfig] = useState<boolean>(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!availableCheckpointConfigs?.includes(model.config)) {
|
||||
setUseCustomConfig(true);
|
||||
}
|
||||
}, [availableCheckpointConfigs, model.config]);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<CheckpointModelConfig>({
|
||||
defaultValues: {
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'main',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
format: 'checkpoint',
|
||||
vae: model.vae ? model.vae : '',
|
||||
config: model.config ? model.config : '',
|
||||
variant: model.variant,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const handleChangeUseCustomConfig = useCallback(() => setUseCustomConfig((prev) => !prev), []);
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<CheckpointModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as CheckpointModelConfig, { keepDefaultValues: true });
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
||||
</Text>
|
||||
</Flex>
|
||||
{![''].includes(model.base) ? (
|
||||
<ModelConvert model={model} />
|
||||
) : (
|
||||
<Badge p={2} borderRadius={4} bg="error.400">
|
||||
{t('modelManager.conversionNotSupported')}
|
||||
</Badge>
|
||||
)}
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<Flex flexDirection="column" maxHeight={window.innerHeight - 270} overflowY="scroll">
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<CheckpointModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<CheckpointModelConfig> control={control} name="variant" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
{!useCustomConfig ? (
|
||||
<CheckpointConfigsSelect control={control} name="config" />
|
||||
) : (
|
||||
<FormControl isRequired>
|
||||
<FormLabel>{t('modelManager.config')}</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.useCustomConfig')}</FormLabel>
|
||||
<Checkbox isChecked={useCustomConfig} onChange={handleChangeUseCustomConfig} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
|
||||
<Button type="submit" isLoading={isLoading}>
|
||||
{t('modelManager.updateModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CheckpointModelEdit);
|
@ -1,133 +0,0 @@
|
||||
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
||||
import ModelVariantSelect from 'features/modelManager/subpanels/shared/ModelVariantSelect';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
type DiffusersModelEditProps = {
|
||||
model: DiffusersModelConfig;
|
||||
};
|
||||
|
||||
const DiffusersModelEdit = (props: DiffusersModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<DiffusersModelConfig>({
|
||||
defaultValues: {
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'main',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
format: 'diffusers',
|
||||
vae: model.vae ? model.vae : '',
|
||||
variant: model.variant,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<DiffusersModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as DiffusersModelConfig, { keepDefaultValues: true });
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<DiffusersModelConfig> control={control} name="base" />
|
||||
<ModelVariantSelect<DiffusersModelConfig> control={control} name="variant" />
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.vaeLocation')}</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
<Button type="submit" isLoading={isLoading}>
|
||||
{t('modelManager.updateModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(DiffusersModelEdit);
|
@ -1,127 +0,0 @@
|
||||
import { Button, Divider, Flex, FormControl, FormErrorMessage, FormLabel, Input, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import BaseModelSelect from 'features/modelManager/subpanels/shared/BaseModelSelect';
|
||||
import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useUpdateModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAModelConfig } from 'services/api/types';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
model: LoRAModelConfig;
|
||||
};
|
||||
|
||||
const LoRAModelEdit = (props: LoRAModelEditProps) => {
|
||||
const { model } = props;
|
||||
|
||||
const [updateModel, { isLoading }] = useUpdateModelsMutation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<LoRAModelConfig>({
|
||||
defaultValues: {
|
||||
name: model.name ? model.name : '',
|
||||
base: model.base,
|
||||
type: 'lora',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
format: model.format,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const onSubmit = useCallback<SubmitHandler<LoRAModelConfig>>(
|
||||
(values) => {
|
||||
const responseBody = {
|
||||
key: model.key,
|
||||
body: values,
|
||||
};
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload as LoRAModelConfig, { keepDefaultValues: true });
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
},
|
||||
[dispatch, model.key, reset, t, updateModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base]} {t('modelManager.model')} ⋅ {LORA_MODEL_FORMAT_MAP[model.format]}{' '}
|
||||
{t('common.format')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<form onSubmit={handleSubmit(onSubmit)}>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<FormControl isInvalid={Boolean(errors.name)}>
|
||||
<FormLabel>{t('modelManager.name')}</FormLabel>
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
/>
|
||||
{errors.name?.message && <FormErrorMessage>{errors.name?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.description')}</FormLabel>
|
||||
<Input {...register('description')} />
|
||||
</FormControl>
|
||||
<BaseModelSelect<LoRAModelConfig> control={control} name="base" />
|
||||
|
||||
<FormControl isInvalid={Boolean(errors.path)}>
|
||||
<FormLabel>{t('modelManager.modelLocation')}</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
{errors.path?.message && <FormErrorMessage>{errors.path?.message}</FormErrorMessage>}
|
||||
</FormControl>
|
||||
<Button type="submit" isLoading={isLoading}>
|
||||
{t('modelManager.updateModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoRAModelEdit);
|
@ -1,165 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
ConfirmationAlertDialog,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Input,
|
||||
ListItem,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
interface ModelConvertProps {
|
||||
model: CheckpointModelConfig;
|
||||
}
|
||||
|
||||
type SaveLocation = 'InvokeAIRoot' | 'Custom';
|
||||
|
||||
const ModelConvert = (props: ModelConvertProps) => {
|
||||
const { model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const [saveLocation, setSaveLocation] = useState<SaveLocation>('InvokeAIRoot');
|
||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||
|
||||
useEffect(() => {
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
}, [model]);
|
||||
|
||||
const modelConvertCancelHandler = useCallback(() => {
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
}, []);
|
||||
|
||||
const handleChangeSaveLocation = useCallback((v: string) => {
|
||||
setSaveLocation(v as SaveLocation);
|
||||
}, []);
|
||||
const handleChangeCustomSaveLocation = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setCustomSaveLocation(e.target.value);
|
||||
}, []);
|
||||
|
||||
const modelConvertHandler = useCallback(() => {
|
||||
const queryArg = {
|
||||
base_model: model.base,
|
||||
model_name: model.name,
|
||||
convert_dest_directory: saveLocation === 'Custom' ? customSaveLocation : undefined,
|
||||
};
|
||||
|
||||
if (saveLocation === 'Custom' && customSaveLocation === '') {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.noCustomLocationProvided'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${model.name}`,
|
||||
status: 'info',
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
convertModel(queryArg)
|
||||
.unwrap()
|
||||
.then(() => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelConverted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch(() => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelConversionFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
}, [convertModel, customSaveLocation, dispatch, model.base, model.name, saveLocation, t]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
onClick={onOpen}
|
||||
size="sm"
|
||||
aria-label={t('modelManager.convertToDiffusers')}
|
||||
className=" modal-close-btn"
|
||||
isLoading={isLoading}
|
||||
>
|
||||
🧨 {t('modelManager.convertToDiffusers')}
|
||||
</Button>
|
||||
<ConfirmationAlertDialog
|
||||
title={`${t('modelManager.convert')} ${model.name}`}
|
||||
acceptCallback={modelConvertHandler}
|
||||
cancelCallback={modelConvertCancelHandler}
|
||||
acceptButtonText={`${t('modelManager.convert')}`}
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
>
|
||||
<Flex flexDirection="column" rowGap={4}>
|
||||
<Text>{t('modelManager.convertToDiffusersHelpText1')}</Text>
|
||||
<UnorderedList>
|
||||
<ListItem>{t('modelManager.convertToDiffusersHelpText2')}</ListItem>
|
||||
<ListItem>{t('modelManager.convertToDiffusersHelpText3')}</ListItem>
|
||||
<ListItem>{t('modelManager.convertToDiffusersHelpText4')}</ListItem>
|
||||
<ListItem>{t('modelManager.convertToDiffusersHelpText5')}</ListItem>
|
||||
</UnorderedList>
|
||||
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Flex marginTop={4} flexDir="column" gap={2}>
|
||||
<Text fontWeight="semibold">{t('modelManager.convertToDiffusersSaveLocation')}</Text>
|
||||
<RadioGroup value={saveLocation} onChange={handleChangeSaveLocation}>
|
||||
<Flex gap={4}>
|
||||
<Radio value="InvokeAIRoot">
|
||||
<Tooltip label="Save converted model in the InvokeAI root folder">
|
||||
{t('modelManager.invokeRoot')}
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
<Radio value="Custom">
|
||||
<Tooltip label="Save converted model in a custom folder">{t('modelManager.custom')}</Tooltip>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex>
|
||||
{saveLocation === 'Custom' && (
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.customSaveLocation')}</FormLabel>
|
||||
<Input width="full" value={customSaveLocation} onChange={handleChangeCustomSaveLocation} />
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelConvert);
|
@ -1,205 +0,0 @@
|
||||
import { Button, ButtonGroup, Flex, FormControl, FormLabel, Input, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type { ChangeEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
// import type { LoRAConfig, MainModelConfig } from 'services/api/endpoints/models';
|
||||
import { useGetLoRAModelsQuery, useGetMainModelsQuery } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
import ModelListItem from './ModelListItem';
|
||||
|
||||
type ModelListProps = {
|
||||
selectedModelId: string | undefined;
|
||||
setSelectedModelId: (name: string | undefined) => void;
|
||||
};
|
||||
|
||||
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
|
||||
|
||||
type ModelType = 'main' | 'lora';
|
||||
|
||||
type CombinedModelFormat = ModelFormat | 'lora';
|
||||
|
||||
const ModelList = (props: ModelListProps) => {
|
||||
const { selectedModelId, setSelectedModelId } = props;
|
||||
const { t } = useTranslation();
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
const [modelFormatFilter, setModelFormatFilter] = useState<CombinedModelFormat>('all');
|
||||
|
||||
const { filteredDiffusersModels, isLoadingDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'main', 'diffusers', nameFilter),
|
||||
isLoadingDiffusersModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels, isLoadingCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'main', 'checkpoint', nameFilter),
|
||||
isLoadingCheckpointModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels, isLoadingLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
isLoadingLoraModels: isLoading,
|
||||
}),
|
||||
});
|
||||
|
||||
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
|
||||
setNameFilter(e.target.value);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
|
||||
<Flex flexDirection="column" gap={4} paddingInlineEnd={4}>
|
||||
<ButtonGroup>
|
||||
<Button onClick={setModelFormatFilter.bind(null, 'all')} isChecked={modelFormatFilter === 'all'} size="sm">
|
||||
{t('modelManager.allModels')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={setModelFormatFilter.bind(null, 'diffusers')}
|
||||
isChecked={modelFormatFilter === 'diffusers'}
|
||||
>
|
||||
{t('modelManager.diffusersModels')}
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={setModelFormatFilter.bind(null, 'checkpoint')}
|
||||
isChecked={modelFormatFilter === 'checkpoint'}
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
</Button>
|
||||
<Button size="sm" onClick={setModelFormatFilter.bind(null, 'lora')} isChecked={modelFormatFilter === 'lora'}>
|
||||
{t('modelManager.loraModels')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.search')}</FormLabel>
|
||||
<Input onChange={handleSearchFilter} />
|
||||
</FormControl>
|
||||
|
||||
<Flex flexDirection="column" gap={4} maxHeight={window.innerHeight - 280} overflow="scroll">
|
||||
{/* Diffusers List */}
|
||||
{isLoadingDiffusersModels && <FetchingModelsLoader loadingMessage="Loading Diffusers..." />}
|
||||
{['all', 'diffusers'].includes(modelFormatFilter) &&
|
||||
!isLoadingDiffusersModels &&
|
||||
filteredDiffusersModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Diffusers"
|
||||
modelList={filteredDiffusersModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="diffusers"
|
||||
/>
|
||||
)}
|
||||
{/* Checkpoints List */}
|
||||
{isLoadingCheckpointModels && <FetchingModelsLoader loadingMessage="Loading Checkpoints..." />}
|
||||
{['all', 'checkpoint'].includes(modelFormatFilter) &&
|
||||
!isLoadingCheckpointModels &&
|
||||
filteredCheckpointModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Checkpoints"
|
||||
modelList={filteredCheckpointModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="checkpoints"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* LoRAs List */}
|
||||
{isLoadingLoraModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
|
||||
{['all', 'lora'].includes(modelFormatFilter) && !isLoadingLoraModels && filteredLoraModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="LoRAs"
|
||||
modelList={filteredLoraModels}
|
||||
selected={{ selectedModelId, setSelectedModelId }}
|
||||
key="loras"
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelList);
|
||||
|
||||
const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
|
||||
data: EntityState<T, string> | undefined,
|
||||
model_type: ModelType,
|
||||
model_format: ModelFormat | undefined,
|
||||
nameFilter: string
|
||||
) => {
|
||||
const filteredModels: T[] = [];
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
|
||||
|
||||
const matchesFormat = model_format === undefined || model.format === model_format;
|
||||
const matchesType = model.type === model_type;
|
||||
|
||||
if (matchesFilter && matchesFormat && matchesType) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
return filteredModels;
|
||||
};
|
||||
|
||||
const StyledModelContainer = memo((props: PropsWithChildren) => {
|
||||
return (
|
||||
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
|
||||
{props.children}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
StyledModelContainer.displayName = 'StyledModelContainer';
|
||||
|
||||
type ModelListWrapperProps = {
|
||||
title: string;
|
||||
modelList: MainModelConfig[] | LoRAConfig[];
|
||||
selected: ModelListProps;
|
||||
};
|
||||
|
||||
const ModelListWrapper = memo((props: ModelListWrapperProps) => {
|
||||
const { title, modelList, selected } = props;
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex gap={2} flexDir="column">
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
{title}
|
||||
</Text>
|
||||
{modelList.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.key}
|
||||
model={model}
|
||||
isSelected={selected.selectedModelId === model.key}
|
||||
setSelectedModelId={selected.setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
});
|
||||
|
||||
ModelListWrapper.displayName = 'ModelListWrapper';
|
||||
|
||||
const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
|
||||
return (
|
||||
<StyledModelContainer>
|
||||
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
|
||||
<Spinner />
|
||||
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
);
|
||||
});
|
||||
|
||||
FetchingModelsLoader.displayName = 'FetchingModelsLoader';
|
@ -1,111 +0,0 @@
|
||||
import {
|
||||
Badge,
|
||||
Button,
|
||||
ConfirmationAlertDialog,
|
||||
Flex,
|
||||
IconButton,
|
||||
Text,
|
||||
Tooltip,
|
||||
useDisclosure,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiTrashSimpleBold } from 'react-icons/pi';
|
||||
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
|
||||
import type { LoRAConfig, MainModelConfig } from 'services/api/types';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: MainModelConfig | LoRAConfig;
|
||||
isSelected: boolean;
|
||||
setSelectedModelId: (v: string | undefined) => void;
|
||||
};
|
||||
|
||||
const ModelListItem = (props: ModelListItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const [deleteModel] = useDeleteModelsMutation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
|
||||
const { model, isSelected, setSelectedModelId } = props;
|
||||
|
||||
const handleSelectModel = useCallback(() => {
|
||||
setSelectedModelId(model.key);
|
||||
}, [model.key, setSelectedModelId]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteModel({ key: model.key })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleted')}: ${model.name}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${model.name}`,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
setSelectedModelId(undefined);
|
||||
}, [deleteModel, model, setSelectedModelId, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<Flex
|
||||
as={Button}
|
||||
isChecked={isSelected}
|
||||
variant={isSelected ? 'solid' : 'ghost'}
|
||||
justifyContent="start"
|
||||
p={2}
|
||||
borderRadius="base"
|
||||
w="full"
|
||||
alignItems="center"
|
||||
onClick={handleSelectModel}
|
||||
>
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
||||
{MODEL_TYPE_SHORT_MAP[model.base as keyof typeof MODEL_TYPE_SHORT_MAP]}
|
||||
</Badge>
|
||||
<Tooltip label={model.description} placement="bottom">
|
||||
<Text>{model.name}</Text>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<IconButton
|
||||
onClick={onOpen}
|
||||
icon={<PiTrashSimpleBold />}
|
||||
aria-label={t('modelManager.deleteConfig')}
|
||||
colorScheme="error"
|
||||
/>
|
||||
<ConfirmationAlertDialog
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
title={t('modelManager.deleteModel')}
|
||||
acceptCallback={handleModelDelete}
|
||||
acceptButtonText={t('modelManager.delete')}
|
||||
>
|
||||
<Flex rowGap={4} flexDirection="column">
|
||||
<Text fontWeight="bold">{t('modelManager.deleteMsg1')}</Text>
|
||||
<Text>{t('modelManager.deleteMsg2')}</Text>
|
||||
</Flex>
|
||||
</ConfirmationAlertDialog>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelListItem);
|
@ -1,14 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
|
||||
import SyncModels from './ModelManagerSettingsPanel/SyncModels';
|
||||
|
||||
const ModelManagerSettingsPanel = () => {
|
||||
return (
|
||||
<Flex>
|
||||
<SyncModels />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ModelManagerSettingsPanel);
|
@ -1,22 +0,0 @@
|
||||
import { Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { SyncModelsButton } from 'features/modelManager/components/SyncModels/SyncModelsButton';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const SyncModels = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex w="full" p={4} borderRadius={4} gap={4} justifyContent="space-between" alignItems="center" bg="base.800">
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<Text fontWeight="semibold">{t('modelManager.syncModels')}</Text>
|
||||
<Text fontSize="sm" variant="subtext">
|
||||
{t('modelManager.syncModelsDesc')}
|
||||
</Text>
|
||||
</Flex>
|
||||
<SyncModelsButton />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SyncModels);
|
@ -1,36 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(BaseModelSelect);
|
@ -1,32 +0,0 @@
|
||||
import type { ChakraProps, ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useController, type UseControllerProps } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetCheckpointConfigsQuery } from 'services/api/endpoints/models';
|
||||
import type { CheckpointModelConfig } from 'services/api/types';
|
||||
|
||||
const sx: ChakraProps['sx'] = { w: 'full' };
|
||||
|
||||
const CheckpointConfigsSelect = (props: UseControllerProps<CheckpointModelConfig>) => {
|
||||
const { data } = useGetCheckpointConfigsQuery();
|
||||
const { t } = useTranslation();
|
||||
const options = useMemo<ComboboxOption[]>(() => (data ? data.map((i) => ({ label: i, value: i })) : []), [data]);
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value, options]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.configFile')}</FormLabel>
|
||||
<Combobox placeholder="Select A Config File" value={value} options={options} onChange={onChange} sx={sx} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CheckpointConfigsSelect);
|
@ -1,34 +0,0 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
{ value: 'inpaint', label: 'Inpaint' },
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
const ModelVariantSelect = <T extends CheckpointModelConfig | DiffusersModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { t } = useTranslation();
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return (
|
||||
<FormControl>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<Combobox value={value} options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelVariantSelect);
|
@ -1,7 +1,7 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { MainModelFieldInputInstance, MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SDXLRefinerModelFieldInputInstance,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SDXLMainModelFieldInputInstance, SDXLMainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { VAEModelFieldInputInstance, VAEModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { pick } from 'lodash-es';
|
||||
|
@ -16,7 +16,7 @@ import { EMPTY_ARRAY } from 'app/store/util';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import { selectLoraSlice } from 'features/lora/store/loraSlice';
|
||||
import { SyncModelsIconButton } from 'features/modelManager/components/SyncModels/SyncModelsIconButton';
|
||||
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
import ParamScheduler from 'features/parameters/components/Core/ParamScheduler';
|
||||
import ParamSteps from 'features/parameters/components/Core/ParamSteps';
|
||||
|
@ -10,7 +10,6 @@ import type {
|
||||
IPAdapterModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
MergeModelConfig,
|
||||
T2IAdapterModelConfig,
|
||||
TextualInversionModelConfig,
|
||||
VAEModelConfig,
|
||||
@ -38,22 +37,9 @@ type DeleteMainModelArg = {
|
||||
|
||||
type DeleteMainModelResponse = void;
|
||||
|
||||
type ConvertMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
convert_dest_directory?: string;
|
||||
};
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
paths['/api/v2/models/convert/{key}']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type MergeMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
body: MergeModelConfig;
|
||||
};
|
||||
|
||||
type MergeMainModelResponse = paths['/api/v2/models/merge']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportMainModelArg = {
|
||||
source: NonNullable<operations['heuristic_install_model']['parameters']['query']['source']>;
|
||||
access_token?: operations['heuristic_install_model']['parameters']['query']['access_token'];
|
||||
@ -72,20 +58,11 @@ type DeleteImportModelsResponse =
|
||||
type PruneModelImportsResponse =
|
||||
paths['/api/v2/models/import']['patch']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ImportAdvancedModelArg = {
|
||||
source: NonNullable<operations['import_model']['requestBody']['content']['application/json']['source']>;
|
||||
config: NonNullable<operations['import_model']['requestBody']['content']['application/json']['config']>;
|
||||
};
|
||||
|
||||
type ImportAdvancedModelResponse = paths['/api/v2/models/import']['post']['responses']['201']['content']['application/json'];
|
||||
|
||||
export type ScanFolderResponse =
|
||||
paths['/api/v2/models/scan_folder']['get']['responses']['200']['content']['application/json'];
|
||||
type ScanFolderArg = operations['scan_for_models']['parameters']['query'];
|
||||
|
||||
type CheckpointConfigsResponse =
|
||||
paths['/api/v2/models/ckpt_confs']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
export const mainModelsAdapter = createEntityAdapter<MainModelConfig, string>({
|
||||
selectId: (entity) => entity.key,
|
||||
sortComparer: (a, b) => a.name.localeCompare(b.name),
|
||||
@ -199,16 +176,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
}),
|
||||
importAdvancedModel: build.mutation<ImportAdvancedModelResponse, ImportAdvancedModelArg>({
|
||||
query: ({ source, config}) => {
|
||||
return {
|
||||
url: buildModelsUrl('install'),
|
||||
method: 'POST',
|
||||
body: { source, config },
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model', 'ModelImports'],
|
||||
}),
|
||||
deleteModels: build.mutation<DeleteMainModelResponse, DeleteMainModelArg>({
|
||||
query: ({ key }) => {
|
||||
return {
|
||||
@ -218,25 +185,14 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
convertMainModels: build.mutation<ConvertMainModelResponse, ConvertMainModelArg>({
|
||||
query: ({ base_model, model_name, convert_dest_directory }) => {
|
||||
convertMainModels: build.mutation<ConvertMainModelResponse, string>({
|
||||
query: (key) => {
|
||||
return {
|
||||
url: buildModelsUrl(`convert/${base_model}/main/${model_name}`),
|
||||
url: buildModelsUrl(`convert/${key}`),
|
||||
method: 'PUT',
|
||||
params: { convert_dest_directory },
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
|
||||
query: ({ base_model, body }) => {
|
||||
return {
|
||||
url: buildModelsUrl(`merge/${base_model}`),
|
||||
method: 'PUT',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
invalidatesTags: ['ModelConfig'],
|
||||
}),
|
||||
getModelConfig: build.query<GetModelConfigResponse, string>({
|
||||
query: (key) => buildModelsUrl(`i/${key}`),
|
||||
@ -323,13 +279,6 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['ModelImports'],
|
||||
}),
|
||||
getCheckpointConfigs: build.query<CheckpointConfigsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
url: buildModelsUrl(`ckpt_confs`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -345,13 +294,10 @@ export const {
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelsMutation,
|
||||
useImportMainModelsMutation,
|
||||
useImportAdvancedModelMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useSyncModelsMutation,
|
||||
useScanModelsQuery,
|
||||
useLazyScanModelsQuery,
|
||||
useGetCheckpointConfigsQuery,
|
||||
useGetModelImportsQuery,
|
||||
useGetModelMetadataQuery,
|
||||
useDeleteModelImportMutation,
|
||||
|
Loading…
Reference in New Issue
Block a user