mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
edit view for model, depending on type and valid values
This commit is contained in:
parent
6b68971f38
commit
0a69779df9
@ -6,6 +6,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
|
||||
type ModelManagerState = {
|
||||
_version: 1;
|
||||
selectedModelKey: string | null;
|
||||
selectedModelMode: "edit" | "view",
|
||||
searchTerm: string;
|
||||
filteredModelType: string | null;
|
||||
};
|
||||
@ -13,6 +14,7 @@ type ModelManagerState = {
|
||||
export const initialModelManagerState: ModelManagerState = {
|
||||
_version: 1,
|
||||
selectedModelKey: null,
|
||||
selectedModelMode: "view",
|
||||
filteredModelType: null,
|
||||
searchTerm: ""
|
||||
};
|
||||
@ -22,8 +24,12 @@ export const modelManagerV2Slice = createSlice({
|
||||
initialState: initialModelManagerState,
|
||||
reducers: {
|
||||
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
||||
state.selectedModelMode = "view"
|
||||
state.selectedModelKey = action.payload;
|
||||
},
|
||||
setSelectedModelMode: (state, action: PayloadAction<"view" | "edit">) => {
|
||||
state.selectedModelMode = action.payload;
|
||||
},
|
||||
setSearchTerm: (state, action: PayloadAction<string>) => {
|
||||
state.searchTerm = action.payload;
|
||||
},
|
||||
@ -34,7 +40,7 @@ export const modelManagerV2Slice = createSlice({
|
||||
},
|
||||
});
|
||||
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions;
|
||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } = modelManagerV2Slice.actions;
|
||||
|
||||
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
import { Box } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from '../../../app/store/storeHooks';
|
||||
import { ImportModels } from './ImportModels';
|
||||
import { ModelView } from './ModelPanel/ModelView';
|
||||
import { Model } from './ModelPanel/Model';
|
||||
|
||||
export const ModelPane = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
return (
|
||||
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
|
||||
{selectedModelKey ? <ModelView /> : <ImportModels />}
|
||||
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,29 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(BaseModelSelect);
|
@ -0,0 +1,27 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: true as any, label: 'True' },
|
||||
{ value: false as any, label: 'False' },
|
||||
];
|
||||
|
||||
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(BooleanSelect);
|
@ -0,0 +1,53 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field, formState } = useController(props);
|
||||
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const options: ComboboxOption[] = useMemo(() => {
|
||||
if (formState.defaultValues?.type === 'lora') {
|
||||
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
||||
value: format,
|
||||
label: LORA_MODEL_FORMAT_MAP[format],
|
||||
})) as ComboboxOption[];
|
||||
} else if (formState.defaultValues?.type === 'embedding') {
|
||||
return [
|
||||
{ value: 'embedding_file', label: 'Embedding File' },
|
||||
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||
];
|
||||
} else if (formState.defaultValues?.type === 'ip_adapter') {
|
||||
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||
} else {
|
||||
return [
|
||||
{ value: 'diffusers', label: 'Diffusers' },
|
||||
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||
];
|
||||
}
|
||||
}, [formState.defaultValues?.type]);
|
||||
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelFormatSelect);
|
@ -0,0 +1,33 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
import { MODEL_TYPE_LABELS } from '../../ModelManagerPanel/ModelTypeFilter';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
||||
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
||||
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
||||
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
||||
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||
];
|
||||
|
||||
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelTypeSelect);
|
@ -0,0 +1,27 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig, CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'normal', label: 'Normal' },
|
||||
{ value: 'inpaint', label: 'Inpaint' },
|
||||
{ value: 'depth', label: 'Depth' },
|
||||
];
|
||||
|
||||
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(ModelVariantSelect);
|
@ -0,0 +1,28 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'epsilon', label: 'epsilon' },
|
||||
{ value: 'v_prediction', label: 'v_prediction' },
|
||||
{ value: 'sample', label: 'sample' },
|
||||
];
|
||||
|
||||
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(PredictionTypeSelect);
|
@ -0,0 +1,30 @@
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'none', label: '-' },
|
||||
{ value: 'fp16', label: 'fp16' },
|
||||
{ value: 'fp32', label: 'fp32' },
|
||||
{ value: 'onnx', label: 'onnx' },
|
||||
{ value: 'openvino', label: 'openvino' },
|
||||
{ value: 'flax', label: 'flax' },
|
||||
];
|
||||
|
||||
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||
const { field } = useController(props);
|
||||
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||
const onChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||
};
|
||||
|
||||
export default typedMemo(RepoVariantSelect);
|
@ -0,0 +1,8 @@
|
||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { ModelEdit } from './ModelEdit';
|
||||
import { ModelView } from './ModelView';
|
||||
|
||||
export const Model = () => {
|
||||
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
|
||||
};
|
@ -0,0 +1,196 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
|
||||
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import {
|
||||
AnyModelConfig,
|
||||
CheckpointModelConfig,
|
||||
ControlNetConfig,
|
||||
DiffusersModelConfig,
|
||||
IPAdapterConfig,
|
||||
LoRAConfig,
|
||||
T2IAdapterConfig,
|
||||
TextualInversionConfig,
|
||||
VAEConfig,
|
||||
} from '../../../../services/api/types';
|
||||
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
||||
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
||||
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||
import BooleanSelect from './Fields/BooleanSelect';
|
||||
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||
|
||||
export const ModelEdit = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
const modelFormat = data.format;
|
||||
const modelType = data.type;
|
||||
|
||||
if (modelType === 'main') {
|
||||
if (modelFormat === 'diffusers') {
|
||||
return data as DiffusersModelConfig;
|
||||
} else if (modelFormat === 'checkpoint') {
|
||||
return data as CheckpointModelConfig;
|
||||
}
|
||||
}
|
||||
|
||||
switch (modelType) {
|
||||
case 'lora':
|
||||
return data as LoRAConfig;
|
||||
case 'embedding':
|
||||
return data as TextualInversionConfig;
|
||||
case 't2i_adapter':
|
||||
return data as T2IAdapterConfig;
|
||||
case 'ip_adapter':
|
||||
return data as IPAdapterConfig;
|
||||
case 'controlnet':
|
||||
return data as ControlNetConfig;
|
||||
case 'vae':
|
||||
return data as VAEConfig;
|
||||
default:
|
||||
return data as DiffusersModelConfig;
|
||||
}
|
||||
}, [data]);
|
||||
|
||||
const {
|
||||
register,
|
||||
handleSubmit,
|
||||
control,
|
||||
formState: { errors },
|
||||
reset,
|
||||
} = useForm<AnyModelConfig>({
|
||||
defaultValues: {
|
||||
...modelData,
|
||||
},
|
||||
mode: 'onChange',
|
||||
});
|
||||
|
||||
const handleClickCancel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>Loading</Text>;
|
||||
}
|
||||
|
||||
if (!modelData) {
|
||||
return <Text>Something went wrong</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||
<Input
|
||||
{...register('name', {
|
||||
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||
})}
|
||||
size="lg"
|
||||
/>
|
||||
<Flex gap={2}>
|
||||
<Button size="sm" onClick={handleClickCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button size="sm" colorScheme="invokeYellow">
|
||||
Save
|
||||
</Button>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={3} mt="4">
|
||||
<Flex>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Description</FormLabel>
|
||||
<Textarea fontSize="md" resize="none" {...register('description')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
Model Settings
|
||||
</Heading>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Base Model</FormLabel>
|
||||
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Model Type</FormLabel>
|
||||
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Format</FormLabel>
|
||||
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Path</FormLabel>
|
||||
<Input
|
||||
{...register('path', {
|
||||
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||
})}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Repo Variant</FormLabel>
|
||||
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||
</FormControl>
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Config Path</FormLabel>
|
||||
<Input {...register('config')} />
|
||||
</FormControl>
|
||||
)}
|
||||
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Variant</FormLabel>
|
||||
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Prediction Type</FormLabel>
|
||||
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Upcast Attention</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>ZTSNR Training</FormLabel>
|
||||
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>VAE Path</FormLabel>
|
||||
<Input {...register('vae')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>Image Encoder Model ID</FormLabel>
|
||||
<Input {...register('image_encoder_model_id')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -1,9 +1,9 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
|
||||
import { Flex, Text, Heading } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
||||
import { useGetModelMetadataQuery, useGetModelQuery } from '../../../../services/api/endpoints/models';
|
||||
import { Flex, Text, Heading, Button, Box } from '@invoke-ai/ui-library';
|
||||
import DataViewer from '../../../gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useMemo } from 'react';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import {
|
||||
CheckpointModelConfig,
|
||||
ControlNetConfig,
|
||||
@ -15,102 +15,128 @@ import {
|
||||
VAEConfig,
|
||||
} from '../../../../services/api/types';
|
||||
import { ModelAttrView } from './ModelAttrView';
|
||||
import { IoPencil } from 'react-icons/io5';
|
||||
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
||||
|
||||
export const ModelView = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
||||
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const modelConfigData = useMemo(() => {
|
||||
const modelData = useMemo(() => {
|
||||
if (!data) {
|
||||
return null;
|
||||
}
|
||||
const modelFormat = data.config.format;
|
||||
const modelType = data.config.type;
|
||||
const modelFormat = data.format;
|
||||
const modelType = data.type;
|
||||
|
||||
if (modelType === 'main') {
|
||||
if (modelFormat === 'diffusers') {
|
||||
return data.config as DiffusersModelConfig;
|
||||
return data as DiffusersModelConfig;
|
||||
} else if (modelFormat === 'checkpoint') {
|
||||
return data.config as CheckpointModelConfig;
|
||||
return data as CheckpointModelConfig;
|
||||
}
|
||||
}
|
||||
|
||||
switch (modelType) {
|
||||
case 'lora':
|
||||
return data.config as LoRAConfig;
|
||||
return data as LoRAConfig;
|
||||
case 'embedding':
|
||||
return data.config as TextualInversionConfig;
|
||||
return data as TextualInversionConfig;
|
||||
case 't2i_adapter':
|
||||
return data.config as T2IAdapterConfig;
|
||||
return data as T2IAdapterConfig;
|
||||
case 'ip_adapter':
|
||||
return data.config as IPAdapterConfig;
|
||||
return data as IPAdapterConfig;
|
||||
case 'controlnet':
|
||||
return data.config as ControlNetConfig;
|
||||
return data as ControlNetConfig;
|
||||
case 'vae':
|
||||
return data.config as VAEConfig;
|
||||
return data as VAEConfig;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}, [data]);
|
||||
|
||||
const handleEditModel = useCallback(() => {
|
||||
dispatch(setSelectedModelMode('edit'));
|
||||
}, [dispatch]);
|
||||
|
||||
if (isLoading) {
|
||||
return <Text>Loading</Text>;
|
||||
}
|
||||
|
||||
if (!modelConfigData) {
|
||||
if (!modelData) {
|
||||
return <Text>Something went wrong</Text>;
|
||||
}
|
||||
return (
|
||||
<Flex flexDir="column" h="full">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{modelConfigData.name}
|
||||
</Heading>
|
||||
{modelConfigData.source && <Text variant="subtext">Source: {modelConfigData.source}</Text>}
|
||||
<Flex w="full" justifyContent="space-between">
|
||||
<Flex flexDir="column" gap={1} p={2}>
|
||||
<Heading as="h2" fontSize="lg">
|
||||
{modelData.name}
|
||||
</Heading>
|
||||
|
||||
{modelData.source && <Text variant="subtext">Source: {modelData.source}</Text>}
|
||||
</Flex>
|
||||
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||
Edit
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" p={2} gap={3}>
|
||||
<Flex>
|
||||
<ModelAttrView label="Description" value={modelConfigData.description} />
|
||||
<ModelAttrView label="Description" value={modelData.description} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Base Model" value={modelConfigData.base} />
|
||||
<ModelAttrView label="Model Type" value={modelConfigData.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Format" value={modelConfigData.format} />
|
||||
<ModelAttrView label="Path" value={modelConfigData.path} />
|
||||
</Flex>
|
||||
{modelConfigData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelConfigData.format === 'diffusers' && (
|
||||
<ModelAttrView label="Repo Variant" value={modelConfigData.repo_variant} />
|
||||
)}
|
||||
{modelConfigData.format === 'checkpoint' && (
|
||||
<ModelAttrView label="Config Path" value={modelConfigData.config} />
|
||||
)}
|
||||
|
||||
<ModelAttrView label="Variant" value={modelConfigData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Prediction Type" value={modelConfigData.prediction_type} />
|
||||
<ModelAttrView label="Upcast Attention" value={`${modelConfigData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="ZTSNR Training" value={`${modelConfigData.ztsnr_training}`} />
|
||||
<ModelAttrView label="VAE" value={modelConfigData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelConfigData.type === 'ip_adapter' && (
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
Model Settings
|
||||
</Heading>
|
||||
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Image Encoder Model ID" value={modelConfigData.image_encoder_model_id} />
|
||||
<ModelAttrView label="Base Model" value={modelData.base} />
|
||||
<ModelAttrView label="Model Type" value={modelData.type} />
|
||||
</Flex>
|
||||
)}
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Format" value={modelData.format} />
|
||||
<ModelAttrView label="Path" value={modelData.path} />
|
||||
</Flex>
|
||||
{modelData.type === 'main' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
{modelData.format === 'diffusers' && (
|
||||
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
|
||||
)}
|
||||
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
|
||||
|
||||
<ModelAttrView label="Variant" value={modelData.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
|
||||
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
|
||||
<ModelAttrView label="VAE" value={modelData.vae} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
{modelData.type === 'ip_adapter' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
|
||||
</Flex>
|
||||
)}
|
||||
</Box>
|
||||
</Flex>
|
||||
|
||||
<Flex h="full">{!!data?.metadata && <DataViewer label="metadata" data={data.metadata} />}</Flex>
|
||||
{metadata && (
|
||||
<>
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
Model Metadata
|
||||
</Heading>
|
||||
<Flex h="full" w="full" p={2}>
|
||||
<DataViewer label="metadata" data={metadata} />
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -2,7 +2,7 @@ import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import type { LoRAModelFormat } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Mapping of model type to human readable name
|
||||
* Mapping of base model to human readable name
|
||||
*/
|
||||
export const MODEL_TYPE_MAP = {
|
||||
any: 'Any',
|
||||
@ -13,7 +13,7 @@ export const MODEL_TYPE_MAP = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of model type to (short) human readable name
|
||||
* Mapping of base model to (short) human readable name
|
||||
*/
|
||||
export const MODEL_TYPE_SHORT_MAP = {
|
||||
any: 'Any',
|
||||
@ -24,7 +24,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Mapping of model type to CLIP skip parameter constraints
|
||||
* Mapping of base model to CLIP skip parameter constraints
|
||||
*/
|
||||
export const CLIP_SKIP_MAP = {
|
||||
any: {
|
||||
|
@ -29,6 +29,8 @@ type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']
|
||||
|
||||
type GetModelResponse =
|
||||
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
type GetModelMetadataResponse =
|
||||
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||
|
||||
@ -175,6 +177,12 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
providesTags: ['Model'],
|
||||
}),
|
||||
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
||||
query: (key) => {
|
||||
return buildModelsUrl(`meta/i/${key}`);
|
||||
},
|
||||
providesTags: ['Model'],
|
||||
}),
|
||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||
query: ({ key, body }) => {
|
||||
return {
|
||||
@ -330,5 +338,6 @@ export const {
|
||||
useGetModelsInFolderQuery,
|
||||
useGetCheckpointConfigsQuery,
|
||||
useGetModelImportsQuery,
|
||||
useGetModelQuery
|
||||
useGetModelQuery,
|
||||
useGetModelMetadataQuery
|
||||
} = modelsApi;
|
||||
|
@ -22,7 +22,7 @@ export type paths = {
|
||||
"/api/v2/models/i/{key}": {
|
||||
/**
|
||||
* Get Model Record
|
||||
* @description Get a model record and metadata
|
||||
* @description Get a model record
|
||||
*/
|
||||
get: operations["get_model_record"];
|
||||
/**
|
||||
@ -4202,13 +4202,6 @@ export type components = {
|
||||
*/
|
||||
type: "freeu";
|
||||
};
|
||||
/** GetModelResponse */
|
||||
GetModelResponse: {
|
||||
/** Config */
|
||||
config: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
|
||||
/** Metadata */
|
||||
metadata: (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null;
|
||||
};
|
||||
/** Graph */
|
||||
Graph: {
|
||||
/**
|
||||
@ -11176,7 +11169,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Get Model Record
|
||||
* @description Get a model record and metadata
|
||||
* @description Get a model record
|
||||
*/
|
||||
get_model_record: {
|
||||
parameters: {
|
||||
@ -11189,7 +11182,7 @@ export type operations = {
|
||||
/** @description The model configuration was retrieved successfully */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["GetModelResponse"];
|
||||
"application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
|
||||
};
|
||||
};
|
||||
/** @description Bad request */
|
||||
@ -11346,10 +11339,6 @@ export type operations = {
|
||||
400: {
|
||||
content: never;
|
||||
};
|
||||
/** @description No metadata available */
|
||||
404: {
|
||||
content: never;
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
|
Loading…
Reference in New Issue
Block a user