mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
edit view for model, depending on type and valid values
This commit is contained in:
parent
6b68971f38
commit
0a69779df9
@ -6,6 +6,7 @@ import type { PersistConfig, RootState } from 'app/store/store';
|
|||||||
type ModelManagerState = {
|
type ModelManagerState = {
|
||||||
_version: 1;
|
_version: 1;
|
||||||
selectedModelKey: string | null;
|
selectedModelKey: string | null;
|
||||||
|
selectedModelMode: "edit" | "view",
|
||||||
searchTerm: string;
|
searchTerm: string;
|
||||||
filteredModelType: string | null;
|
filteredModelType: string | null;
|
||||||
};
|
};
|
||||||
@ -13,6 +14,7 @@ type ModelManagerState = {
|
|||||||
export const initialModelManagerState: ModelManagerState = {
|
export const initialModelManagerState: ModelManagerState = {
|
||||||
_version: 1,
|
_version: 1,
|
||||||
selectedModelKey: null,
|
selectedModelKey: null,
|
||||||
|
selectedModelMode: "view",
|
||||||
filteredModelType: null,
|
filteredModelType: null,
|
||||||
searchTerm: ""
|
searchTerm: ""
|
||||||
};
|
};
|
||||||
@ -22,8 +24,12 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
initialState: initialModelManagerState,
|
initialState: initialModelManagerState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
setSelectedModelKey: (state, action: PayloadAction<string | null>) => {
|
||||||
|
state.selectedModelMode = "view"
|
||||||
state.selectedModelKey = action.payload;
|
state.selectedModelKey = action.payload;
|
||||||
},
|
},
|
||||||
|
setSelectedModelMode: (state, action: PayloadAction<"view" | "edit">) => {
|
||||||
|
state.selectedModelMode = action.payload;
|
||||||
|
},
|
||||||
setSearchTerm: (state, action: PayloadAction<string>) => {
|
setSearchTerm: (state, action: PayloadAction<string>) => {
|
||||||
state.searchTerm = action.payload;
|
state.searchTerm = action.payload;
|
||||||
},
|
},
|
||||||
@ -34,7 +40,7 @@ export const modelManagerV2Slice = createSlice({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType } = modelManagerV2Slice.actions;
|
export const { setSelectedModelKey, setSearchTerm, setFilteredModelType, setSelectedModelMode } = modelManagerV2Slice.actions;
|
||||||
|
|
||||||
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
|
||||||
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
import { Box } from '@invoke-ai/ui-library';
|
import { Box } from '@invoke-ai/ui-library';
|
||||||
import { useAppSelector } from '../../../app/store/storeHooks';
|
import { useAppSelector } from '../../../app/store/storeHooks';
|
||||||
import { ImportModels } from './ImportModels';
|
import { ImportModels } from './ImportModels';
|
||||||
import { ModelView } from './ModelPanel/ModelView';
|
import { Model } from './ModelPanel/Model';
|
||||||
|
|
||||||
export const ModelPane = () => {
|
export const ModelPane = () => {
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
return (
|
return (
|
||||||
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
|
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
|
||||||
{selectedModelKey ? <ModelView /> : <ImportModels />}
|
{selectedModelKey ? <Model /> : <ImportModels />}
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
|
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||||
|
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||||
|
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||||
|
];
|
||||||
|
|
||||||
|
const BaseModelSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(BaseModelSelect);
|
@ -0,0 +1,27 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'none', label: '-' },
|
||||||
|
{ value: true as any, label: 'True' },
|
||||||
|
{ value: false as any, label: 'False' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const BooleanSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(BooleanSelect);
|
@ -0,0 +1,53 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { LORA_MODEL_FORMAT_MAP, MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||||
|
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||||
|
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||||
|
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||||
|
];
|
||||||
|
|
||||||
|
const ModelFormatSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field, formState } = useController(props);
|
||||||
|
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = useMemo(() => {
|
||||||
|
if (formState.defaultValues?.type === 'lora') {
|
||||||
|
return Object.keys(LORA_MODEL_FORMAT_MAP).map((format) => ({
|
||||||
|
value: format,
|
||||||
|
label: LORA_MODEL_FORMAT_MAP[format],
|
||||||
|
})) as ComboboxOption[];
|
||||||
|
} else if (formState.defaultValues?.type === 'embedding') {
|
||||||
|
return [
|
||||||
|
{ value: 'embedding_file', label: 'Embedding File' },
|
||||||
|
{ value: 'embedding_folder', label: 'Embedding Folder' },
|
||||||
|
];
|
||||||
|
} else if (formState.defaultValues?.type === 'ip_adapter') {
|
||||||
|
return [{ value: 'invokeai', label: 'invokeai' }];
|
||||||
|
} else {
|
||||||
|
return [
|
||||||
|
{ value: 'diffusers', label: 'Diffusers' },
|
||||||
|
{ value: 'checkpoint', label: 'Checkpoint' },
|
||||||
|
];
|
||||||
|
}
|
||||||
|
}, [formState.defaultValues?.type]);
|
||||||
|
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [options, field.value]);
|
||||||
|
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(ModelFormatSelect);
|
@ -0,0 +1,33 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
import { MODEL_TYPE_LABELS } from '../../ModelManagerPanel/ModelTypeFilter';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'main', label: MODEL_TYPE_LABELS['main'] as string },
|
||||||
|
{ value: 'lora', label: MODEL_TYPE_LABELS['lora'] as string },
|
||||||
|
{ value: 'embedding', label: MODEL_TYPE_LABELS['embedding'] as string },
|
||||||
|
{ value: 'vae', label: MODEL_TYPE_LABELS['vae'] as string },
|
||||||
|
{ value: 'controlnet', label: MODEL_TYPE_LABELS['controlnet'] as string },
|
||||||
|
{ value: 'ip_adapter', label: MODEL_TYPE_LABELS['ip_adapter'] as string },
|
||||||
|
{ value: 't2i_adapater', label: MODEL_TYPE_LABELS['t2i_adapter'] as string },
|
||||||
|
];
|
||||||
|
|
||||||
|
const ModelTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(ModelTypeSelect);
|
@ -0,0 +1,27 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig, CheckpointModelConfig, DiffusersModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'normal', label: 'Normal' },
|
||||||
|
{ value: 'inpaint', label: 'Inpaint' },
|
||||||
|
{ value: 'depth', label: 'Depth' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const ModelVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(ModelVariantSelect);
|
@ -0,0 +1,28 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'none', label: '-' },
|
||||||
|
{ value: 'epsilon', label: 'epsilon' },
|
||||||
|
{ value: 'v_prediction', label: 'v_prediction' },
|
||||||
|
{ value: 'sample', label: 'sample' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const PredictionTypeSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(PredictionTypeSelect);
|
@ -0,0 +1,30 @@
|
|||||||
|
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
const options: ComboboxOption[] = [
|
||||||
|
{ value: 'none', label: '-' },
|
||||||
|
{ value: 'fp16', label: 'fp16' },
|
||||||
|
{ value: 'fp32', label: 'fp32' },
|
||||||
|
{ value: 'onnx', label: 'onnx' },
|
||||||
|
{ value: 'openvino', label: 'openvino' },
|
||||||
|
{ value: 'flax', label: 'flax' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const RepoVariantSelect = <T extends AnyModelConfig>(props: UseControllerProps<T>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === field.value), [field.value]);
|
||||||
|
const onChange = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
v?.value === 'none' ? field.onChange(undefined) : field.onChange(v?.value);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
return <Combobox value={value} options={options} onChange={onChange} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(RepoVariantSelect);
|
@ -0,0 +1,8 @@
|
|||||||
|
import { useAppSelector } from '../../../../app/store/storeHooks';
|
||||||
|
import { ModelEdit } from './ModelEdit';
|
||||||
|
import { ModelView } from './ModelView';
|
||||||
|
|
||||||
|
export const Model = () => {
|
||||||
|
const selectedModelMode = useAppSelector((s) => s.modelmanagerV2.selectedModelMode);
|
||||||
|
return selectedModelMode === 'view' ? <ModelView /> : <ModelEdit />;
|
||||||
|
};
|
@ -0,0 +1,196 @@
|
|||||||
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
|
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
||||||
|
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
|
||||||
|
import { Flex, Text, Heading, Button, Input, FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import {
|
||||||
|
AnyModelConfig,
|
||||||
|
CheckpointModelConfig,
|
||||||
|
ControlNetConfig,
|
||||||
|
DiffusersModelConfig,
|
||||||
|
IPAdapterConfig,
|
||||||
|
LoRAConfig,
|
||||||
|
T2IAdapterConfig,
|
||||||
|
TextualInversionConfig,
|
||||||
|
VAEConfig,
|
||||||
|
} from '../../../../services/api/types';
|
||||||
|
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
||||||
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
|
import { useForm } from 'react-hook-form';
|
||||||
|
import ModelTypeSelect from './Fields/ModelTypeSelect';
|
||||||
|
import ModelVariantSelect from './Fields/ModelVariantSelect';
|
||||||
|
import RepoVariantSelect from './Fields/RepoVariantSelect';
|
||||||
|
import PredictionTypeSelect from './Fields/PredictionTypeSelect';
|
||||||
|
import BooleanSelect from './Fields/BooleanSelect';
|
||||||
|
import ModelFormatSelect from './Fields/ModelFormatSelect';
|
||||||
|
|
||||||
|
export const ModelEdit = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
|
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
|
const modelData = useMemo(() => {
|
||||||
|
if (!data) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const modelFormat = data.format;
|
||||||
|
const modelType = data.type;
|
||||||
|
|
||||||
|
if (modelType === 'main') {
|
||||||
|
if (modelFormat === 'diffusers') {
|
||||||
|
return data as DiffusersModelConfig;
|
||||||
|
} else if (modelFormat === 'checkpoint') {
|
||||||
|
return data as CheckpointModelConfig;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (modelType) {
|
||||||
|
case 'lora':
|
||||||
|
return data as LoRAConfig;
|
||||||
|
case 'embedding':
|
||||||
|
return data as TextualInversionConfig;
|
||||||
|
case 't2i_adapter':
|
||||||
|
return data as T2IAdapterConfig;
|
||||||
|
case 'ip_adapter':
|
||||||
|
return data as IPAdapterConfig;
|
||||||
|
case 'controlnet':
|
||||||
|
return data as ControlNetConfig;
|
||||||
|
case 'vae':
|
||||||
|
return data as VAEConfig;
|
||||||
|
default:
|
||||||
|
return data as DiffusersModelConfig;
|
||||||
|
}
|
||||||
|
}, [data]);
|
||||||
|
|
||||||
|
const {
|
||||||
|
register,
|
||||||
|
handleSubmit,
|
||||||
|
control,
|
||||||
|
formState: { errors },
|
||||||
|
reset,
|
||||||
|
} = useForm<AnyModelConfig>({
|
||||||
|
defaultValues: {
|
||||||
|
...modelData,
|
||||||
|
},
|
||||||
|
mode: 'onChange',
|
||||||
|
});
|
||||||
|
|
||||||
|
const handleClickCancel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return <Text>Loading</Text>;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!modelData) {
|
||||||
|
return <Text>Something went wrong</Text>;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<Flex flexDir="column" h="full">
|
||||||
|
<Flex w="full" justifyContent="space-between" gap={4} alignItems="center">
|
||||||
|
<Input
|
||||||
|
{...register('name', {
|
||||||
|
validate: (value) => value.trim().length > 3 || 'Must be at least 3 characters',
|
||||||
|
})}
|
||||||
|
size="lg"
|
||||||
|
/>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<Button size="sm" onClick={handleClickCancel}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button size="sm" colorScheme="invokeYellow">
|
||||||
|
Save
|
||||||
|
</Button>
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
|
<Flex flexDir="column" gap={3} mt="4">
|
||||||
|
<Flex>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Description</FormLabel>
|
||||||
|
<Textarea fontSize="md" resize="none" {...register('description')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
|
Model Settings
|
||||||
|
</Heading>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Base Model</FormLabel>
|
||||||
|
<BaseModelSelect<AnyModelConfig> control={control} name="base" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Model Type</FormLabel>
|
||||||
|
<ModelTypeSelect<AnyModelConfig> control={control} name="type" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Format</FormLabel>
|
||||||
|
<ModelFormatSelect<AnyModelConfig> control={control} name="format" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Path</FormLabel>
|
||||||
|
<Input
|
||||||
|
{...register('path', {
|
||||||
|
validate: (value) => value.trim().length > 0 || 'Must provide a path',
|
||||||
|
})}
|
||||||
|
/>
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
{modelData.type === 'main' && (
|
||||||
|
<>
|
||||||
|
<Flex gap={4}>
|
||||||
|
{modelData.format === 'diffusers' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Repo Variant</FormLabel>
|
||||||
|
<RepoVariantSelect<AnyModelConfig> control={control} name="repo_variant" />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
{modelData.format === 'checkpoint' && (
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Config Path</FormLabel>
|
||||||
|
<Input {...register('config')} />
|
||||||
|
</FormControl>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Variant</FormLabel>
|
||||||
|
<ModelVariantSelect<AnyModelConfig> control={control} name="variant" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Prediction Type</FormLabel>
|
||||||
|
<PredictionTypeSelect<AnyModelConfig> control={control} name="prediction_type" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Upcast Attention</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="upcast_attention" />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>ZTSNR Training</FormLabel>
|
||||||
|
<BooleanSelect<AnyModelConfig> control={control} name="ztsnr_training" />
|
||||||
|
</FormControl>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>VAE Path</FormLabel>
|
||||||
|
<Input {...register('vae')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{modelData.type === 'ip_adapter' && (
|
||||||
|
<Flex gap={4}>
|
||||||
|
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||||
|
<FormLabel>Image Encoder Model ID</FormLabel>
|
||||||
|
<Input {...register('image_encoder_model_id')} />
|
||||||
|
</FormControl>
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Flex>
|
||||||
|
</Flex>
|
||||||
|
);
|
||||||
|
};
|
@ -1,9 +1,9 @@
|
|||||||
import { skipToken } from '@reduxjs/toolkit/query';
|
import { skipToken } from '@reduxjs/toolkit/query';
|
||||||
import { useAppSelector } from '../../../../app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from '../../../../app/store/storeHooks';
|
||||||
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
|
import { useGetModelMetadataQuery, useGetModelQuery } from '../../../../services/api/endpoints/models';
|
||||||
import { Flex, Text, Heading } from '@invoke-ai/ui-library';
|
import { Flex, Text, Heading, Button, Box } from '@invoke-ai/ui-library';
|
||||||
import DataViewer from '../../../gallery/components/ImageMetadataViewer/DataViewer';
|
import DataViewer from '../../../gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
import { useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
import {
|
import {
|
||||||
CheckpointModelConfig,
|
CheckpointModelConfig,
|
||||||
ControlNetConfig,
|
ControlNetConfig,
|
||||||
@ -15,102 +15,128 @@ import {
|
|||||||
VAEConfig,
|
VAEConfig,
|
||||||
} from '../../../../services/api/types';
|
} from '../../../../services/api/types';
|
||||||
import { ModelAttrView } from './ModelAttrView';
|
import { ModelAttrView } from './ModelAttrView';
|
||||||
|
import { IoPencil } from 'react-icons/io5';
|
||||||
|
import { setSelectedModelMode } from '../../store/modelManagerV2Slice';
|
||||||
|
|
||||||
export const ModelView = () => {
|
export const ModelView = () => {
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
|
||||||
|
const { data: metadata } = useGetModelMetadataQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const modelConfigData = useMemo(() => {
|
const modelData = useMemo(() => {
|
||||||
if (!data) {
|
if (!data) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
const modelFormat = data.config.format;
|
const modelFormat = data.format;
|
||||||
const modelType = data.config.type;
|
const modelType = data.type;
|
||||||
|
|
||||||
if (modelType === 'main') {
|
if (modelType === 'main') {
|
||||||
if (modelFormat === 'diffusers') {
|
if (modelFormat === 'diffusers') {
|
||||||
return data.config as DiffusersModelConfig;
|
return data as DiffusersModelConfig;
|
||||||
} else if (modelFormat === 'checkpoint') {
|
} else if (modelFormat === 'checkpoint') {
|
||||||
return data.config as CheckpointModelConfig;
|
return data as CheckpointModelConfig;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (modelType) {
|
switch (modelType) {
|
||||||
case 'lora':
|
case 'lora':
|
||||||
return data.config as LoRAConfig;
|
return data as LoRAConfig;
|
||||||
case 'embedding':
|
case 'embedding':
|
||||||
return data.config as TextualInversionConfig;
|
return data as TextualInversionConfig;
|
||||||
case 't2i_adapter':
|
case 't2i_adapter':
|
||||||
return data.config as T2IAdapterConfig;
|
return data as T2IAdapterConfig;
|
||||||
case 'ip_adapter':
|
case 'ip_adapter':
|
||||||
return data.config as IPAdapterConfig;
|
return data as IPAdapterConfig;
|
||||||
case 'controlnet':
|
case 'controlnet':
|
||||||
return data.config as ControlNetConfig;
|
return data as ControlNetConfig;
|
||||||
case 'vae':
|
case 'vae':
|
||||||
return data.config as VAEConfig;
|
return data as VAEConfig;
|
||||||
default:
|
default:
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}, [data]);
|
}, [data]);
|
||||||
|
|
||||||
|
const handleEditModel = useCallback(() => {
|
||||||
|
dispatch(setSelectedModelMode('edit'));
|
||||||
|
}, [dispatch]);
|
||||||
|
|
||||||
if (isLoading) {
|
if (isLoading) {
|
||||||
return <Text>Loading</Text>;
|
return <Text>Loading</Text>;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!modelConfigData) {
|
if (!modelData) {
|
||||||
return <Text>Something went wrong</Text>;
|
return <Text>Something went wrong</Text>;
|
||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" h="full">
|
<Flex flexDir="column" h="full">
|
||||||
<Flex flexDir="column" gap={1} p={2}>
|
<Flex w="full" justifyContent="space-between">
|
||||||
<Heading as="h2" fontSize="lg">
|
<Flex flexDir="column" gap={1} p={2}>
|
||||||
{modelConfigData.name}
|
<Heading as="h2" fontSize="lg">
|
||||||
</Heading>
|
{modelData.name}
|
||||||
{modelConfigData.source && <Text variant="subtext">Source: {modelConfigData.source}</Text>}
|
</Heading>
|
||||||
|
|
||||||
|
{modelData.source && <Text variant="subtext">Source: {modelData.source}</Text>}
|
||||||
|
</Flex>
|
||||||
|
<Button size="sm" leftIcon={<IoPencil />} colorScheme="invokeYellow" onClick={handleEditModel}>
|
||||||
|
Edit
|
||||||
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex flexDir="column" p={2} gap={3}>
|
<Flex flexDir="column" p={2} gap={3}>
|
||||||
<Flex>
|
<Flex>
|
||||||
<ModelAttrView label="Description" value={modelConfigData.description} />
|
<ModelAttrView label="Description" value={modelData.description} />
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex gap={2}>
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
<ModelAttrView label="Base Model" value={modelConfigData.base} />
|
Model Settings
|
||||||
<ModelAttrView label="Model Type" value={modelConfigData.type} />
|
</Heading>
|
||||||
</Flex>
|
<Box layerStyle="second" borderRadius="base" p={3}>
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="Format" value={modelConfigData.format} />
|
|
||||||
<ModelAttrView label="Path" value={modelConfigData.path} />
|
|
||||||
</Flex>
|
|
||||||
{modelConfigData.type === 'main' && (
|
|
||||||
<>
|
|
||||||
<Flex gap={2}>
|
|
||||||
{modelConfigData.format === 'diffusers' && (
|
|
||||||
<ModelAttrView label="Repo Variant" value={modelConfigData.repo_variant} />
|
|
||||||
)}
|
|
||||||
{modelConfigData.format === 'checkpoint' && (
|
|
||||||
<ModelAttrView label="Config Path" value={modelConfigData.config} />
|
|
||||||
)}
|
|
||||||
|
|
||||||
<ModelAttrView label="Variant" value={modelConfigData.variant} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="Prediction Type" value={modelConfigData.prediction_type} />
|
|
||||||
<ModelAttrView label="Upcast Attention" value={`${modelConfigData.upcast_attention}`} />
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<ModelAttrView label="ZTSNR Training" value={`${modelConfigData.ztsnr_training}`} />
|
|
||||||
<ModelAttrView label="VAE" value={modelConfigData.vae} />
|
|
||||||
</Flex>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{modelConfigData.type === 'ip_adapter' && (
|
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<ModelAttrView label="Image Encoder Model ID" value={modelConfigData.image_encoder_model_id} />
|
<ModelAttrView label="Base Model" value={modelData.base} />
|
||||||
|
<ModelAttrView label="Model Type" value={modelData.type} />
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Format" value={modelData.format} />
|
||||||
|
<ModelAttrView label="Path" value={modelData.path} />
|
||||||
|
</Flex>
|
||||||
|
{modelData.type === 'main' && (
|
||||||
|
<>
|
||||||
|
<Flex gap={2}>
|
||||||
|
{modelData.format === 'diffusers' && (
|
||||||
|
<ModelAttrView label="Repo Variant" value={modelData.repo_variant} />
|
||||||
|
)}
|
||||||
|
{modelData.format === 'checkpoint' && <ModelAttrView label="Config Path" value={modelData.config} />}
|
||||||
|
|
||||||
|
<ModelAttrView label="Variant" value={modelData.variant} />
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Prediction Type" value={modelData.prediction_type} />
|
||||||
|
<ModelAttrView label="Upcast Attention" value={`${modelData.upcast_attention}`} />
|
||||||
|
</Flex>
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="ZTSNR Training" value={`${modelData.ztsnr_training}`} />
|
||||||
|
<ModelAttrView label="VAE" value={modelData.vae} />
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{modelData.type === 'ip_adapter' && (
|
||||||
|
<Flex gap={2}>
|
||||||
|
<ModelAttrView label="Image Encoder Model ID" value={modelData.image_encoder_model_id} />
|
||||||
|
</Flex>
|
||||||
|
)}
|
||||||
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Flex h="full">{!!data?.metadata && <DataViewer label="metadata" data={data.metadata} />}</Flex>
|
{metadata && (
|
||||||
|
<>
|
||||||
|
<Heading as="h3" fontSize="md" mt="4">
|
||||||
|
Model Metadata
|
||||||
|
</Heading>
|
||||||
|
<Flex h="full" w="full" p={2}>
|
||||||
|
<DataViewer label="metadata" data={metadata} />
|
||||||
|
</Flex>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -2,7 +2,7 @@ import type { ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import type { LoRAModelFormat } from 'services/api/types';
|
import type { LoRAModelFormat } from 'services/api/types';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mapping of model type to human readable name
|
* Mapping of base model to human readable name
|
||||||
*/
|
*/
|
||||||
export const MODEL_TYPE_MAP = {
|
export const MODEL_TYPE_MAP = {
|
||||||
any: 'Any',
|
any: 'Any',
|
||||||
@ -13,7 +13,7 @@ export const MODEL_TYPE_MAP = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mapping of model type to (short) human readable name
|
* Mapping of base model to (short) human readable name
|
||||||
*/
|
*/
|
||||||
export const MODEL_TYPE_SHORT_MAP = {
|
export const MODEL_TYPE_SHORT_MAP = {
|
||||||
any: 'Any',
|
any: 'Any',
|
||||||
@ -24,7 +24,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mapping of model type to CLIP skip parameter constraints
|
* Mapping of base model to CLIP skip parameter constraints
|
||||||
*/
|
*/
|
||||||
export const CLIP_SKIP_MAP = {
|
export const CLIP_SKIP_MAP = {
|
||||||
any: {
|
any: {
|
||||||
|
@ -29,6 +29,8 @@ type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']
|
|||||||
|
|
||||||
type GetModelResponse =
|
type GetModelResponse =
|
||||||
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
type GetModelMetadataResponse =
|
||||||
|
paths['/api/v2/models/meta/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
|
||||||
|
|
||||||
@ -175,6 +177,12 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
providesTags: ['Model'],
|
providesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
|
getModelMetadata: build.query<GetModelMetadataResponse, string>({
|
||||||
|
query: (key) => {
|
||||||
|
return buildModelsUrl(`meta/i/${key}`);
|
||||||
|
},
|
||||||
|
providesTags: ['Model'],
|
||||||
|
}),
|
||||||
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
|
||||||
query: ({ key, body }) => {
|
query: ({ key, body }) => {
|
||||||
return {
|
return {
|
||||||
@ -330,5 +338,6 @@ export const {
|
|||||||
useGetModelsInFolderQuery,
|
useGetModelsInFolderQuery,
|
||||||
useGetCheckpointConfigsQuery,
|
useGetCheckpointConfigsQuery,
|
||||||
useGetModelImportsQuery,
|
useGetModelImportsQuery,
|
||||||
useGetModelQuery
|
useGetModelQuery,
|
||||||
|
useGetModelMetadataQuery
|
||||||
} = modelsApi;
|
} = modelsApi;
|
||||||
|
@ -22,7 +22,7 @@ export type paths = {
|
|||||||
"/api/v2/models/i/{key}": {
|
"/api/v2/models/i/{key}": {
|
||||||
/**
|
/**
|
||||||
* Get Model Record
|
* Get Model Record
|
||||||
* @description Get a model record and metadata
|
* @description Get a model record
|
||||||
*/
|
*/
|
||||||
get: operations["get_model_record"];
|
get: operations["get_model_record"];
|
||||||
/**
|
/**
|
||||||
@ -4202,13 +4202,6 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
type: "freeu";
|
type: "freeu";
|
||||||
};
|
};
|
||||||
/** GetModelResponse */
|
|
||||||
GetModelResponse: {
|
|
||||||
/** Config */
|
|
||||||
config: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
|
|
||||||
/** Metadata */
|
|
||||||
metadata: (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null;
|
|
||||||
};
|
|
||||||
/** Graph */
|
/** Graph */
|
||||||
Graph: {
|
Graph: {
|
||||||
/**
|
/**
|
||||||
@ -11176,7 +11169,7 @@ export type operations = {
|
|||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Get Model Record
|
* Get Model Record
|
||||||
* @description Get a model record and metadata
|
* @description Get a model record
|
||||||
*/
|
*/
|
||||||
get_model_record: {
|
get_model_record: {
|
||||||
parameters: {
|
parameters: {
|
||||||
@ -11189,7 +11182,7 @@ export type operations = {
|
|||||||
/** @description The model configuration was retrieved successfully */
|
/** @description The model configuration was retrieved successfully */
|
||||||
200: {
|
200: {
|
||||||
content: {
|
content: {
|
||||||
"application/json": components["schemas"]["GetModelResponse"];
|
"application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
/** @description Bad request */
|
/** @description Bad request */
|
||||||
@ -11346,10 +11339,6 @@ export type operations = {
|
|||||||
400: {
|
400: {
|
||||||
content: never;
|
content: never;
|
||||||
};
|
};
|
||||||
/** @description No metadata available */
|
|
||||||
404: {
|
|
||||||
content: never;
|
|
||||||
};
|
|
||||||
/** @description Validation Error */
|
/** @description Validation Error */
|
||||||
422: {
|
422: {
|
||||||
content: {
|
content: {
|
||||||
|
Loading…
Reference in New Issue
Block a user