mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add LoRAs to the model manager.
This commit is contained in:
parent
bb9460d278
commit
10e1d623c3
@ -340,6 +340,7 @@
|
||||
"allModels": "All Models",
|
||||
"checkpointModels": "Checkpoints",
|
||||
"diffusersModels": "Diffusers",
|
||||
"loraModels": "LoRAs",
|
||||
"safetensorModels": "SafeTensors",
|
||||
"modelAdded": "Model Added",
|
||||
"modelUpdated": "Model Updated",
|
||||
|
@ -1,3 +1,5 @@
|
||||
import { components } from 'services/api/schema';
|
||||
|
||||
export const MODEL_TYPE_MAP = {
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
@ -5,6 +7,13 @@ export const MODEL_TYPE_MAP = {
|
||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||
};
|
||||
|
||||
export const MODEL_TYPE_SHORT_MAP = {
|
||||
'sd-1': 'SD1',
|
||||
'sd-2': 'SD2',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
};
|
||||
|
||||
export const clipSkipMap = {
|
||||
'sd-1': {
|
||||
maxClip: 12,
|
||||
@ -23,3 +32,12 @@ export const clipSkipMap = {
|
||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||
},
|
||||
};
|
||||
|
||||
type LoRAModelFormatMap = {
|
||||
[key in components['schemas']['LoRAModelFormat']]: string;
|
||||
};
|
||||
|
||||
export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = {
|
||||
lycoris: 'LyCORIS',
|
||||
diffusers: 'Diffusers',
|
||||
};
|
||||
|
@ -3,20 +3,31 @@ import { Flex, Text } from '@chakra-ui/react';
|
||||
import { useState } from 'react';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
DiffusersModelConfigEntity,
|
||||
LoRAModelConfigEntity,
|
||||
useGetMainModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
} from 'services/api/endpoints/models';
|
||||
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
|
||||
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
|
||||
import LoRAModelEdit from './ModelManagerPanel/LoRAModelEdit';
|
||||
import ModelList from './ModelManagerPanel/ModelList';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
|
||||
export default function ModelManagerPanel() {
|
||||
const [selectedModelId, setSelectedModelId] = useState<string>();
|
||||
const { model } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
const { mainModel } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
mainModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
});
|
||||
const { loraModel } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
loraModel: selectedModelId ? data?.entities[selectedModelId] : undefined,
|
||||
}),
|
||||
});
|
||||
|
||||
const model = mainModel ? mainModel : loraModel;
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
|
||||
@ -30,7 +41,7 @@ export default function ModelManagerPanel() {
|
||||
}
|
||||
|
||||
type ModelEditProps = {
|
||||
model: MainModelConfigEntity | undefined;
|
||||
model: MainModelConfigEntity | LoRAModelConfigEntity | undefined;
|
||||
};
|
||||
|
||||
const ModelEdit = (props: ModelEditProps) => {
|
||||
@ -41,7 +52,16 @@ const ModelEdit = (props: ModelEditProps) => {
|
||||
}
|
||||
|
||||
if (model?.model_format === 'diffusers') {
|
||||
return <DiffusersModelEdit key={model.id} model={model} />;
|
||||
return (
|
||||
<DiffusersModelEdit
|
||||
key={model.id}
|
||||
model={model as DiffusersModelConfigEntity}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (model?.model_type === 'lora') {
|
||||
return <LoRAModelEdit key={model.id} model={model} />;
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -0,0 +1,82 @@
|
||||
import { Divider, Flex, Text } from '@chakra-ui/react';
|
||||
import { useForm } from '@mantine/form';
|
||||
import IAIMantineTextInput from 'common/components/IAIMantineInput';
|
||||
import {
|
||||
LORA_MODEL_FORMAT_MAP,
|
||||
MODEL_TYPE_MAP,
|
||||
} from 'features/parameters/types/constants';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { LoRAModelConfigEntity } from 'services/api/endpoints/models';
|
||||
import { LoRAModelConfig } from 'services/api/types';
|
||||
import BaseModelSelect from '../shared/BaseModelSelect';
|
||||
|
||||
type LoRAModelEditProps = {
|
||||
model: LoRAModelConfigEntity;
|
||||
};
|
||||
|
||||
export default function LoRAModelEdit(props: LoRAModelEditProps) {
|
||||
const { model } = props;
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const loraEditForm = useForm<LoRAModelConfig>({
|
||||
initialValues: {
|
||||
model_name: model.model_name ? model.model_name : '',
|
||||
base_model: model.base_model,
|
||||
model_type: 'lora',
|
||||
path: model.path ? model.path : '',
|
||||
description: model.description ? model.description : '',
|
||||
model_format: model.model_format,
|
||||
},
|
||||
validate: {
|
||||
path: (value) =>
|
||||
value.trim().length === 0 ? 'Must provide a path' : null,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex flexDirection="column" rowGap={4} width="100%">
|
||||
<Flex flexDirection="column">
|
||||
<Text fontSize="lg" fontWeight="bold">
|
||||
{model.model_name}
|
||||
</Text>
|
||||
<Text fontSize="sm" color="base.400">
|
||||
{MODEL_TYPE_MAP[model.base_model]} Model ⋅{' '}
|
||||
{LORA_MODEL_FORMAT_MAP[model.model_format]} format
|
||||
</Text>
|
||||
</Flex>
|
||||
<Divider />
|
||||
|
||||
<form>
|
||||
<Flex flexDirection="column" overflowY="scroll" gap={4}>
|
||||
<IAIMantineTextInput
|
||||
label={t('modelManager.name')}
|
||||
readOnly={true}
|
||||
disabled={true}
|
||||
{...loraEditForm.getInputProps('model_name')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
label={t('modelManager.description')}
|
||||
readOnly={true}
|
||||
disabled={true}
|
||||
{...loraEditForm.getInputProps('description')}
|
||||
/>
|
||||
<BaseModelSelect
|
||||
readOnly={true}
|
||||
disabled={true}
|
||||
{...loraEditForm.getInputProps('base_model')}
|
||||
/>
|
||||
<IAIMantineTextInput
|
||||
readOnly={true}
|
||||
disabled={true}
|
||||
label={t('modelManager.modelLocation')}
|
||||
{...loraEditForm.getInputProps('path')}
|
||||
/>
|
||||
<Text color="base.400">
|
||||
{t('Editing LoRA model metadata is not yet supported.')}
|
||||
</Text>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
);
|
||||
}
|
@ -9,6 +9,8 @@ import { useTranslation } from 'react-i18next';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
useGetMainModelsQuery,
|
||||
useGetLoRAModelsQuery,
|
||||
LoRAModelConfigEntity,
|
||||
} from 'services/api/endpoints/models';
|
||||
import ModelListItem from './ModelListItem';
|
||||
import { ALL_BASE_MODELS } from 'services/api/constants';
|
||||
@ -20,22 +22,42 @@ type ModelListProps = {
|
||||
|
||||
type ModelFormat = 'images' | 'checkpoint' | 'diffusers';
|
||||
|
||||
type ModelType = 'main' | 'lora';
|
||||
|
||||
type CombinedModelFormat = ModelFormat | 'lora';
|
||||
|
||||
const ModelList = (props: ModelListProps) => {
|
||||
const { selectedModelId, setSelectedModelId } = props;
|
||||
const { t } = useTranslation();
|
||||
const [nameFilter, setNameFilter] = useState<string>('');
|
||||
const [modelFormatFilter, setModelFormatFilter] =
|
||||
useState<ModelFormat>('images');
|
||||
useState<CombinedModelFormat>('images');
|
||||
|
||||
const { filteredDiffusersModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
|
||||
filteredDiffusersModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'diffusers',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredCheckpointModels } = useGetMainModelsQuery(ALL_BASE_MODELS, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
|
||||
filteredCheckpointModels: modelsFilter(
|
||||
data,
|
||||
'main',
|
||||
'checkpoint',
|
||||
nameFilter
|
||||
),
|
||||
}),
|
||||
});
|
||||
|
||||
const { filteredLoraModels } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: ({ data }) => ({
|
||||
filteredLoraModels: modelsFilter(data, 'lora', undefined, nameFilter),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -68,6 +90,13 @@ const ModelList = (props: ModelListProps) => {
|
||||
>
|
||||
{t('modelManager.checkpointModels')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
size="sm"
|
||||
onClick={() => setModelFormatFilter('lora')}
|
||||
isChecked={modelFormatFilter === 'lora'}
|
||||
>
|
||||
{t('modelManager.loraModels')}
|
||||
</IAIButton>
|
||||
</ButtonGroup>
|
||||
|
||||
<IAIInput
|
||||
@ -118,6 +147,24 @@ const ModelList = (props: ModelListProps) => {
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
{['images', 'lora'].includes(modelFormatFilter) &&
|
||||
filteredLoraModels.length > 0 && (
|
||||
<StyledModelContainer>
|
||||
<Flex sx={{ gap: 2, flexDir: 'column' }}>
|
||||
<Text variant="subtext" fontSize="sm">
|
||||
LoRAs
|
||||
</Text>
|
||||
{filteredLoraModels.map((model) => (
|
||||
<ModelListItem
|
||||
key={model.id}
|
||||
model={model}
|
||||
isSelected={selectedModelId === model.id}
|
||||
setSelectedModelId={setSelectedModelId}
|
||||
/>
|
||||
))}
|
||||
</Flex>
|
||||
</StyledModelContainer>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
@ -126,12 +173,13 @@ const ModelList = (props: ModelListProps) => {
|
||||
|
||||
export default ModelList;
|
||||
|
||||
const modelsFilter = (
|
||||
data: EntityState<MainModelConfigEntity> | undefined,
|
||||
model_format: ModelFormat,
|
||||
const modelsFilter = <T extends MainModelConfigEntity | LoRAModelConfigEntity>(
|
||||
data: EntityState<T> | undefined,
|
||||
model_type: ModelType,
|
||||
model_format: ModelFormat | undefined,
|
||||
nameFilter: string
|
||||
) => {
|
||||
const filteredModels: MainModelConfigEntity[] = [];
|
||||
const filteredModels: T[] = [];
|
||||
forEach(data?.entities, (model) => {
|
||||
if (!model) {
|
||||
return;
|
||||
@ -141,9 +189,11 @@ const modelsFilter = (
|
||||
.toLowerCase()
|
||||
.includes(nameFilter.toLowerCase());
|
||||
|
||||
const matchesFormat = model.model_format === model_format;
|
||||
const matchesFormat =
|
||||
model_format === undefined || model.model_format === model_format;
|
||||
const matchesType = model.model_type === model_type;
|
||||
|
||||
if (matchesFilter && matchesFormat) {
|
||||
if (matchesFilter && matchesFormat && matchesType) {
|
||||
filteredModels.push(model);
|
||||
}
|
||||
});
|
||||
|
@ -9,29 +9,26 @@ import { selectIsBusy } from 'features/system/store/systemSelectors';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
|
||||
import {
|
||||
MainModelConfigEntity,
|
||||
LoRAModelConfigEntity,
|
||||
useDeleteMainModelsMutation,
|
||||
useDeleteLoRAModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
type ModelListItemProps = {
|
||||
model: MainModelConfigEntity;
|
||||
model: MainModelConfigEntity | LoRAModelConfigEntity;
|
||||
isSelected: boolean;
|
||||
setSelectedModelId: (v: string | undefined) => void;
|
||||
};
|
||||
|
||||
const modelBaseTypeMap = {
|
||||
'sd-1': 'SD1',
|
||||
'sd-2': 'SD2',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
};
|
||||
|
||||
export default function ModelListItem(props: ModelListItemProps) {
|
||||
const isBusy = useAppSelector(selectIsBusy);
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const [deleteMainModel] = useDeleteMainModelsMutation();
|
||||
const [deleteLoRAModel] = useDeleteLoRAModelsMutation();
|
||||
|
||||
const { model, isSelected, setSelectedModelId } = props;
|
||||
|
||||
@ -40,7 +37,10 @@ export default function ModelListItem(props: ModelListItemProps) {
|
||||
}, [model.id, setSelectedModelId]);
|
||||
|
||||
const handleModelDelete = useCallback(() => {
|
||||
deleteMainModel(model)
|
||||
const method = { main: deleteMainModel, lora: deleteLoRAModel }[
|
||||
model.model_type
|
||||
];
|
||||
method(model)
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
@ -60,14 +60,21 @@ export default function ModelListItem(props: ModelListItemProps) {
|
||||
title: `${t('modelManager.modelDeleteFailed')}: ${
|
||||
model.model_name
|
||||
}`,
|
||||
status: 'success',
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
setSelectedModelId(undefined);
|
||||
}, [deleteMainModel, model, setSelectedModelId, dispatch, t]);
|
||||
}, [
|
||||
deleteMainModel,
|
||||
deleteLoRAModel,
|
||||
model,
|
||||
setSelectedModelId,
|
||||
dispatch,
|
||||
t,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
|
||||
@ -100,8 +107,8 @@ export default function ModelListItem(props: ModelListItemProps) {
|
||||
<Flex gap={4} alignItems="center">
|
||||
<Badge minWidth={14} p={0.5} fontSize="sm" variant="solid">
|
||||
{
|
||||
modelBaseTypeMap[
|
||||
model.base_model as keyof typeof modelBaseTypeMap
|
||||
MODEL_TYPE_SHORT_MAP[
|
||||
model.base_model as keyof typeof MODEL_TYPE_SHORT_MAP
|
||||
]
|
||||
}
|
||||
</Badge>
|
||||
|
@ -62,6 +62,10 @@ type DeleteMainModelArg = {
|
||||
|
||||
type DeleteMainModelResponse = void;
|
||||
|
||||
type DeleteLoRAModelArg = DeleteMainModelArg;
|
||||
|
||||
type DeleteLoRAModelResponse = void;
|
||||
|
||||
type ConvertMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
@ -320,6 +324,18 @@ export const modelsApi = api.injectEndpoints({
|
||||
);
|
||||
},
|
||||
}),
|
||||
deleteLoRAModels: build.mutation<
|
||||
DeleteLoRAModelResponse,
|
||||
DeleteLoRAModelArg
|
||||
>({
|
||||
query: ({ base_model, model_name }) => {
|
||||
return {
|
||||
url: `models/${base_model}/lora/${model_name}`,
|
||||
method: 'DELETE',
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'LoRAModel', id: LIST_TAG }],
|
||||
}),
|
||||
getControlNetModels: build.query<
|
||||
EntityState<ControlNetModelConfigEntity>,
|
||||
void
|
||||
@ -467,6 +483,7 @@ export const {
|
||||
useAddMainModelsMutation,
|
||||
useConvertMainModelsMutation,
|
||||
useMergeMainModelsMutation,
|
||||
useDeleteLoRAModelsMutation,
|
||||
useSyncModelsMutation,
|
||||
useGetModelsInFolderQuery,
|
||||
useGetCheckpointConfigsQuery,
|
||||
|
Loading…
Reference in New Issue
Block a user