single model view

This commit is contained in:
Mary Hipp 2024-02-21 09:39:02 -05:00 committed by psychedelicious
parent b8b3ef9725
commit 4fd259bb89
8 changed files with 195 additions and 27 deletions

View File

@ -3,10 +3,12 @@ import {
Button,
ConfirmationAlertDialog,
Flex,
Icon,
IconButton,
Text,
Tooltip,
useDisclosure,
Box,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
@ -15,6 +17,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { IoWarning } from 'react-icons/io5';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useDeleteModelsMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
@ -88,8 +91,16 @@ const ModelListItem = (props: ModelListItemProps) => {
<Tooltip label={model.description} placement="bottom">
<Text>{model.name}</Text>
</Tooltip>
{model.format === 'checkpoint' && (
<Tooltip label="Checkpoint">
<Box>
<Icon as={IoWarning} />
</Box>
</Tooltip>
)}
</Flex>
</Flex>
<IconButton
onClick={onOpen}
icon={<PiTrashSimpleBold />}

View File

@ -0,0 +1,13 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from '../../../app/store/storeHooks';
import { ImportModels } from './ImportModels';
import { ModelView } from './ModelPanel/ModelView';
export const ModelPane = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
return (
<Box layerStyle="first" p={2} borderRadius="base" w="full" h="full">
{selectedModelKey ? <ModelView /> : <ImportModels />}
</Box>
);
};

View File

@ -0,0 +1,15 @@
import { FormControl, FormLabel, Text } from '@invoke-ai/ui-library';
interface Props {
label: string;
value: string | null | undefined;
}
export const ModelAttrView = ({ label, value }: Props) => {
return (
<FormControl flexDir="column" alignItems="flex-start" gap={0}>
<FormLabel>{label}</FormLabel>
<Text fontSize="md">{value || '-'}</Text>
</FormControl>
);
};

View File

@ -0,0 +1,116 @@
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from '../../../../app/store/storeHooks';
import { useGetModelQuery } from '../../../../services/api/endpoints/models';
import { Flex, Text, Heading } from '@invoke-ai/ui-library';
import DataViewer from '../../../gallery/components/ImageMetadataViewer/DataViewer';
import { useMemo } from 'react';
import {
CheckpointModelConfig,
ControlNetConfig,
DiffusersModelConfig,
IPAdapterConfig,
LoRAConfig,
T2IAdapterConfig,
TextualInversionConfig,
VAEConfig,
} from '../../../../services/api/types';
import { ModelAttrView } from './ModelAttrView';
export const ModelView = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelQuery(selectedModelKey ?? skipToken);
const modelConfigData = useMemo(() => {
if (!data) {
return null;
}
const modelFormat = data.config.format;
const modelType = data.config.type;
if (modelType === 'main') {
if (modelFormat === 'diffusers') {
return data.config as DiffusersModelConfig;
} else if (modelFormat === 'checkpoint') {
return data.config as CheckpointModelConfig;
}
}
switch (modelType) {
case 'lora':
return data.config as LoRAConfig;
case 'embedding':
return data.config as TextualInversionConfig;
case 't2i_adapter':
return data.config as T2IAdapterConfig;
case 'ip_adapter':
return data.config as IPAdapterConfig;
case 'controlnet':
return data.config as ControlNetConfig;
case 'vae':
return data.config as VAEConfig;
default:
return null;
}
}, [data]);
if (isLoading) {
return <Text>Loading</Text>;
}
if (!modelConfigData) {
return <Text>Something went wrong</Text>;
}
return (
<Flex flexDir="column" h="full">
<Flex flexDir="column" gap={1} p={2}>
<Heading as="h2" fontSize="lg">
{modelConfigData.name}
</Heading>
{modelConfigData.source && <Text variant="subtext">Source: {modelConfigData.source}</Text>}
</Flex>
<Flex flexDir="column" p={2} gap={3}>
<Flex>
<ModelAttrView label="Description" value={modelConfigData.description} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Base Model" value={modelConfigData.base} />
<ModelAttrView label="Model Type" value={modelConfigData.type} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Format" value={modelConfigData.format} />
<ModelAttrView label="Path" value={modelConfigData.path} />
</Flex>
{modelConfigData.type === 'main' && (
<>
<Flex gap={2}>
{modelConfigData.format === 'diffusers' && (
<ModelAttrView label="Repo Variant" value={modelConfigData.repo_variant} />
)}
{modelConfigData.format === 'checkpoint' && (
<ModelAttrView label="Config Path" value={modelConfigData.config} />
)}
<ModelAttrView label="Variant" value={modelConfigData.variant} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="Prediction Type" value={modelConfigData.prediction_type} />
<ModelAttrView label="Upcast Attention" value={`${modelConfigData.upcast_attention}`} />
</Flex>
<Flex gap={2}>
<ModelAttrView label="ZTSNR Training" value={`${modelConfigData.ztsnr_training}`} />
<ModelAttrView label="VAE" value={modelConfigData.vae} />
</Flex>
</>
)}
{modelConfigData.type === 'ip_adapter' && (
<Flex gap={2}>
<ModelAttrView label="Image Encoder Model ID" value={modelConfigData.image_encoder_model_id} />
</Flex>
)}
</Flex>
<Flex h="full">{!!data?.metadata && <DataViewer label="metadata" data={data.metadata} />}</Flex>
</Flex>
);
};

View File

@ -1,14 +0,0 @@
import { z } from "zod";
export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
export const zModelType = z.enum([
'main',
'vae',
'lora',
'controlnet',
'embedding',
'ip_adapter',
'clip_vision',
't2i_adapter',
'onnx', // TODO(psyche): Remove this when removed from backend
]);

View File

@ -1,16 +1,25 @@
import { Box,Flex } from '@invoke-ai/ui-library';
import { ImportModels } from 'features/modelManagerV2/subpanels/ImportModels';
import { ModelManager } from 'features/modelManagerV2/subpanels/ModelManager';
import { memo } from 'react';
import { Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs, Box, Button } from '@invoke-ai/ui-library';
import ImportModelsPanel from 'features/modelManager/subpanels/ImportModelsPanel';
import MergeModelsPanel from 'features/modelManager/subpanels/MergeModelsPanel';
import ModelManagerPanel from 'features/modelManager/subpanels/ModelManagerPanel';
import ModelManagerSettingsPanel from 'features/modelManager/subpanels/ModelManagerSettingsPanel';
import type { ReactNode } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { SyncModelsIconButton } from '../../../modelManager/components/SyncModels/SyncModelsIconButton';
import { ModelManager } from '../../../modelManagerV2/subpanels/ModelManager';
import { ModelPane } from '../../../modelManagerV2/subpanels/ModelPane';
type ModelManagerTabName = 'modelManager' | 'importModels' | 'mergeModels' | 'settings';
const ModelManagerTab = () => {
const { t } = useTranslation();
return (
<Box w="full" h="full">
<Flex w="full" h="full" gap={4}>
<ModelManager />
<ImportModels />
</Flex>
</Box>
<Flex w="full" h="full" gap="2">
<ModelManager />
<ModelPane />
</Flex>
);
};

View File

@ -26,6 +26,10 @@ type UpdateModelArg = {
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type GetModelResponse =
paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
type ListModelsArg = NonNullable<paths['/api/v2/models/']['get']['parameters']['query']>;
type DeleteMainModelArg = {
@ -165,6 +169,12 @@ export const modelsApi = api.injectEndpoints({
providesTags: buildProvidesTags<MainModelConfig>('MainModel'),
transformResponse: buildTransformResponse<MainModelConfig>(mainModelsAdapter),
}),
getModel: build.query<GetModelResponse, string>({
query: (key) => {
return buildModelsUrl(`i/${key}`);
},
providesTags: ['Model'],
}),
updateModels: build.mutation<UpdateModelResponse, UpdateModelArg>({
query: ({ key, body }) => {
return {
@ -320,4 +330,5 @@ export const {
useGetModelsInFolderQuery,
useGetCheckpointConfigsQuery,
useGetModelImportsQuery,
useGetModelQuery
} = modelsApi;

View File

@ -22,7 +22,7 @@ export type paths = {
"/api/v2/models/i/{key}": {
/**
* Get Model Record
* @description Get a model record
* @description Get a model record and metadata
*/
get: operations["get_model_record"];
/**
@ -4202,6 +4202,13 @@ export type components = {
*/
type: "freeu";
};
/** GetModelResponse */
GetModelResponse: {
/** Config */
config: (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
/** Metadata */
metadata: (components["schemas"]["BaseMetadata"] | components["schemas"]["HuggingFaceMetadata"] | components["schemas"]["CivitaiMetadata"]) | null;
};
/** Graph */
Graph: {
/**
@ -11169,7 +11176,7 @@ export type operations = {
};
/**
* Get Model Record
* @description Get a model record
* @description Get a model record and metadata
*/
get_model_record: {
parameters: {
@ -11182,7 +11189,7 @@ export type operations = {
/** @description The model configuration was retrieved successfully */
200: {
content: {
"application/json": (components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"]) | (components["schemas"]["ONNXSD1Config"] | components["schemas"]["ONNXSD2Config"] | components["schemas"]["ONNXSDXLConfig"]) | (components["schemas"]["VaeDiffusersConfig"] | components["schemas"]["VaeCheckpointConfig"]) | (components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"]) | components["schemas"]["LoRAConfig"] | components["schemas"]["TextualInversionConfig"] | components["schemas"]["IPAdapterConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["T2IConfig"];
"application/json": components["schemas"]["GetModelResponse"];
};
};
/** @description Bad request */