Merge branch 'main' into lstein/model-manager-route-enhancements

This commit is contained in:
Lincoln Stein 2023-07-14 13:52:55 -04:00 committed by GitHub
commit e71ce83e9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 1273 additions and 983 deletions

View File

@ -5,6 +5,7 @@ from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from typing import Union
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
@ -398,8 +399,8 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on

View File

@ -614,6 +614,7 @@ class ModelManager(object):
rmtree(str(model_path))
else:
model_path.unlink()
self.commit()
# LS: tested
def add_model(

View File

@ -343,6 +343,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
@ -397,8 +398,8 @@
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
@ -409,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
@ -420,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@ -446,7 +449,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",
@ -599,7 +603,6 @@
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
"tooltip": {
"feature": {

View File

@ -9,9 +9,9 @@ import { theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = {
children: ReactNode;
@ -35,8 +35,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction;
}, [direction]);
const mantineTheme = useMantineTheme();
return (
<MantineProvider withGlobalStyles theme={mantineTheme}>
<MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>

View File

@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => {
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
},
});
};

View File

@ -100,10 +100,11 @@ export const store = configureStore({
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) {
// don't log api actions, with manual cache updates they are extremely noisy
return false;
}
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions

View File

@ -1,11 +1,13 @@
import { Flex } from '@chakra-ui/react';
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
@ -14,7 +16,7 @@ export function IAIFormItemWrapper({
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}

View File

@ -0,0 +1,44 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
const { ...rest } = props;
const {
base50,
base100,
base200,
base300,
base800,
base700,
base900,
accent500,
accent300,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
})}
{...rest}
/>
);
}

View File

@ -1,10 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
@ -14,25 +13,6 @@ type IAIMultiSelectProps = MultiSelectProps & {
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
@ -52,6 +32,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
[dispatch]
);
const styles = useMantineMultiSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
@ -60,92 +42,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
styles={styles}
{...rest}
/>
</Tooltip>

View File

@ -0,0 +1,78 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = SelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={styles}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSearchableSelect);

View File

@ -1,10 +1,7 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
export type IAISelectDataType = {
value: string;
@ -18,157 +15,13 @@ type IAISelectProps = SelectProps & {
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { tooltip, inputRef, ...rest } = props;
const { colorMode } = useColorMode();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest}
/>
<Select ref={inputRef} styles={styles} {...rest} />
</Tooltip>
);
};

View File

@ -24,7 +24,7 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
@ -213,7 +213,7 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer}
data={LAYER_NAMES_DICT}

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
@ -48,7 +48,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
data={controlNetModels}
value={model}
onChange={handleModelChanged}

View File

@ -1,8 +1,12 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
@ -11,10 +15,6 @@ import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
@ -72,7 +72,7 @@ const ParamControlNetProcessorSelect = (
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors}

View File

@ -9,7 +9,7 @@ import {
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
@ -106,7 +106,7 @@ const ParamEmbeddingPopover = (props: Props) => {
</Text>
</Flex>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
inputRef={inputRef}
autoFocus
placeholder={'Add Embedding'}

View File

@ -12,10 +12,10 @@ import {
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll);
@ -56,7 +56,7 @@ const UpdateImageBoardModal = () => {
{isFetching ? (
<Spinner />
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}

View File

@ -4,7 +4,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
@ -84,7 +84,7 @@ const ParamLoRASelect = () => {
}
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={null}
data={data}

View File

@ -3,7 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css';
@ -77,7 +77,7 @@ const AddNodeMenu = () => {
return (
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAIMantineSelect
<IAIMantineSearchableSelect
selectOnBlur={false}
placeholder="Add Node"
value={null}

View File

@ -1,7 +1,7 @@
import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
@ -89,7 +89,7 @@ const LoRAModelInputFieldComponent = (
}
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
value={selectedLoRAModel?.id ?? null}
label={
selectedLoRAModel?.base_model &&

View File

@ -6,7 +6,7 @@ import {
} from 'features/nodes/types/types';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
@ -81,14 +81,14 @@ const ModelInputFieldComponent = (
);
return isLoading ? (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]

View File

@ -1,6 +1,6 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
@ -86,7 +86,7 @@ const VaeModelInputFieldComponent = (
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import {
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
};
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scaleBeforeProcessing')}
data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale}

View File

@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
@ -48,7 +48,7 @@ const ParamScheduler = () => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scheduler')}
value={scheduler}
data={data}

View File

@ -1,7 +1,7 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
FacetoolType,
setFacetoolType,
@ -20,7 +20,7 @@ export default function FaceRestoreType() {
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.type')}
data={FACETOOL_TYPES.concat()}
value={facetoolType}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
@ -77,14 +77,14 @@ const ParamMainModelSelect = () => {
);
return isLoading ? (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.id}

View File

@ -1,7 +1,7 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
UpscalingLevel,
setUpscalingLevel,
@ -24,7 +24,7 @@ export default function UpscaleScale() {
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={String(upscalingLevel)}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
@ -92,7 +92,7 @@ const ParamVAEModelSelect = () => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing',
'totalIterations',
'totalSteps',
'openModel',
'isCancelScheduled',
'progressImage',
'wereModelsReceived',

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { reduce, pickBy } from 'lodash-es';
import { pickBy, reduce } from 'lodash-es';
export const systemSelector = (state: RootState) => state.system;
@ -50,3 +50,8 @@ export const languageSelector = createSelector(
export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing;
export const selectIsBusy = createSelector(
(state: RootState) => state,
(state) => state.system.isProcessing || !state.system.isConnected
);

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/**
* The current progress image
*/
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [],
searchFolder: null,
foundModels: null,
openModel: null,
progressImage: null,
shouldAntialiasProgressImage: false,
sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => {
state.foundModels = action.payload;
},
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/**
* A cancel was scheduled
*/
@ -433,7 +428,6 @@ export const {
clearToastQueue,
setSearchFolder,
setFoundModels,
setOpenModel,
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,

View File

@ -13,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode;
};
const modelManagerTabs: ModelManagerTabInfo[] = [
const tabs: ModelManagerTabInfo[] = [
{
id: 'modelManager',
label: i18n.t('modelManager.modelManager'),
@ -31,49 +31,28 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
},
];
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
const ModelManagerTab = () => {
return (
<Tabs
isLazy
variant="invokeAI"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
variant="line"
layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
>
{renderTabsList()}
{renderTabPanels()}
<TabList>
{tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}
</TabPanel>
))}
</TabPanels>
</Tabs>
);
};

View File

@ -1,4 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react';
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -12,6 +12,8 @@ export default function AddModelsPanel() {
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -20,27 +22,13 @@ export default function AddModelsPanel() {
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
sx={{
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
},
}}
isChecked={addNewModelUIOption == 'ckpt'}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
sx={{
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
},
}}
isChecked={addNewModelUIOption == 'diffusers'}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>

View File

@ -66,13 +66,13 @@ export default function AddDiffusersModel() {
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}>
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
@ -90,7 +90,6 @@ export default function AddDiffusersModel() {
name="name"
type="text"
validate={baseValidation}
width="2xl"
isRequired
/>
{!!errors.name && touched.name ? (
@ -119,7 +118,6 @@ export default function AddDiffusersModel() {
id="description"
name="description"
type="text"
width="2xl"
isRequired
/>
{!!errors.description && touched.description ? (
@ -153,13 +151,7 @@ export default function AddDiffusersModel() {
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="2xl"
/>
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
@ -181,7 +173,6 @@ export default function AddDiffusersModel() {
id="repo_id"
name="repo_id"
type="text"
width="2xl"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
@ -220,7 +211,6 @@ export default function AddDiffusersModel() {
id="vae.path"
name="vae.path"
type="text"
width="2xl"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
@ -245,7 +235,6 @@ export default function AddDiffusersModel() {
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="2xl"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>

View File

@ -1,42 +1,85 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
Flex,
Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es';
import { useState } from 'react';
import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
type MergeInterpolationMethods =
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference';
export default function MergeModelsPanel() {
const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy(
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers'
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
);
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
const sd2DiffusersModels = pickBy(
data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
);
const [modelThree, setModelThree] = useState<string>('none');
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeInterp, setModelMergeInterp] =
useState<MergeInterpolationMethods>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
@ -47,41 +90,73 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
const modelOneList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType);
setModelOne(null);
setModelTwo(null);
};
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const models_names: string[] = [];
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
// model_merge_save_path:
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
});
};
return (
@ -90,7 +165,6 @@ export default function MergeModelsPanel() {
sx={{
flexDirection: 'column',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +172,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
<IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
w="100%"
value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/>
<IAISelect
<IAIMantineSearchableSelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
w="100%"
placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/>
<IAISelect
<IAIMantineSearchableSelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
data={modelThreeList}
w="100%"
placeholder={t('modelManager.selectModel')}
clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelThree(v);
setModelMergeInterp('weighted_sum');
}
}}
@ -136,7 +227,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<IAISlider
@ -161,7 +252,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
@ -169,12 +260,10 @@ export default function MergeModelsPanel() {
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
) => setModelMergeInterp(v)}
onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
{modelThree === null ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +288,7 @@ export default function MergeModelsPanel() {
</RadioGroup>
</Flex>
<Flex
{/* <Flex
sx={{
flexDirection: 'column',
padding: 4,
@ -235,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
</Flex> */}
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
@ -246,10 +335,8 @@ export default function MergeModelsPanel() {
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
isLoading={isLoading}
isDisabled={modelOne === null || modelTwo === null}
>
{t('modelManager.merge')}
</IAIButton>

View File

@ -1,44 +1,60 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { Flex, Text } from '@chakra-ui/react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery();
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
<ModelList
selectedModelId={selectedModelId}
setSelectedModelId={setSelectedModelId}
/>
<ModelEdit model={model} />
</Flex>
);
}
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
maxH: 96,
userSelect: 'none',
}}
>
<Text variant="subtext">No Model Selected</Text>
</Flex>
);
};

View File

@ -1,17 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
@ -25,55 +28,92 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: CheckpointModel;
model: CheckpointModelConfigEntity;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const isBusy = useAppSelector(selectIsBusy);
const { modelToEdit, retrievedModel } = props;
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm({
const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: {
name: retrievedModel.model_name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'checkpoint',
vae: retrievedModel.vae,
config: retrievedModel.config,
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
config: model.config ? model.config : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<ModelConvert model={retrievedModel} />
<ModelConvert model={model} />
</Flex>
<Divider />
@ -88,11 +128,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')}
/>
@ -106,36 +142,28 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
<IAIButton disabled={isProcessing} type="submit">
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,25 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form';
import type { RootState } from 'app/store/store';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
type DiffusersModel =
| S<'StableDiffusion1ModelDiffusersConfig'>
| S<'StableDiffusion2ModelDiffusersConfig'>;
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
DiffusersModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: DiffusersModel;
model: DiffusersModelConfigEntity;
};
const baseModelSelectData = [
@ -34,39 +32,82 @@ const variantSelectData = [
];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { retrievedModel, modelToEdit } = props;
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm({
const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: {
name: retrievedModel.model_name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'diffusers',
vae: retrievedModel.vae,
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
const editModelFormSubmitHandler = useCallback(
(values: DiffusersModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<Divider />
@ -77,11 +118,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')}
/>
@ -95,31 +132,23 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.vaeLocation')}
{...diffusersEditForm.getInputProps('vae')}
/>
<IAIButton disabled={isProcessing} type="submit">
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,23 +1,18 @@
import {
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
model: CheckpointModel;
model: CheckpointModelConfig;
}
export default function ModelConvert(props: ModelConvertProps) {
@ -26,6 +21,8 @@ export default function ModelConvert(props: ModelConvertProps) {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -38,20 +35,39 @@ export default function ModelConvert(props: ModelConvertProps) {
};
const modelConvertHandler = () => {
const modelToConvert = {
model_name: model,
save_location: saveLocation,
custom_location:
saveLocation === 'custom' && customSaveLocation !== ''
? customSaveLocation
: null,
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
};
dispatch(convertToDiffusers(modelToConvert));
convertModel(responseBody)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${
model.model_name
}`,
status: 'error',
})
)
);
});
};
return (
<IAIAlertDialog
title={`${t('modelManager.convert')} ${model.name}`}
title={`${t('modelManager.convert')} ${model.model_name}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`}
@ -60,6 +76,7 @@ export default function ModelConvert(props: ModelConvertProps) {
size={'sm'}
aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn"
isLoading={isLoading}
>
🧨 {t('modelManager.convertToDiffusers')}
</IAIButton>
@ -77,7 +94,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
<Flex flexDir="column" gap={4}>
{/* <Flex flexDir="column" gap={4}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
@ -103,9 +120,9 @@ export default function ModelConvert(props: ModelConvertProps) {
</Radio>
</Flex>
</RadioGroup>
</Flex>
</Flex> */}
{saveLocation === 'custom' && (
{/* {saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')}
@ -119,8 +136,7 @@ export default function ModelConvert(props: ModelConvertProps) {
width="full"
/>
</Flex>
)}
</Flex>
)} */}
</IAIAlertDialog>
);
}

View File

@ -1,185 +1,46 @@
import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton
onClick={onClick}
isActive={isActive}
sx={{
_active: {
bg: 'accent.750',
},
}}
size="sm"
>
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
};
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}),
});
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}),
});
if (!mainModels) return;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model]?.model_name
.toLowerCase()
.includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}
if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
});
return searchText !== '' ? (
isSelectedFilter === 'all' ? (
<Box marginTop={4}>{filteredModelListItemsToRender}</Box>
) : (
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box>
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
</>
)}
{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter]);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -187,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter}
label={t('modelManager.search')}
/>
<Flex
flexDirection="column"
gap={4}
@ -195,39 +55,89 @@ const ModelList = () => {
overflow="scroll"
paddingInlineEnd={4}
>
<Flex columnGap={2}>
<ModelFilterButton
label={t('modelManager.allModels')}
onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'}
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
<Spinner />
</Flex>
)}
{t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
filteredCheckpointModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
</Flex>
</Flex>
);
};
export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,98 +1,89 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteIcon } from '@chakra-ui/icons';
import { Flex, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { setOpenModel } from 'features/system/store/systemSlice';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = {
modelKey: string;
name: string;
description: string | undefined;
model: MainModelConfigEntity;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};
export default function ModelListItem(props: ModelListItemProps) {
const { isProcessing, isConnected } = useAppSelector(
(state: RootState) => state.system
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const [deleteMainModel] = useDeleteMainModelsMutation();
const dispatch = useAppDispatch();
const { model, isSelected, setSelectedModelId } = props;
const { modelKey, name, description } = props;
const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const openModelHandler = () => {
dispatch(setOpenModel(modelKey));
};
const handleModelDelete = () => {
dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null));
};
const handleModelDelete = useCallback(() => {
deleteMainModel(model);
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
return (
<Flex
alignItems="center"
p={2}
borderRadius="base"
sx={
modelKey === openModel
? {
bg: 'accent.750',
_hover: {
bg: 'accent.750',
},
}
: {
_hover: {
bg: 'base.750',
},
}
}
>
<Box onClick={openModelHandler} cursor="pointer">
<Tooltip label={description} hasArrow placement="bottom">
<Text fontWeight="600">{name}</Text>
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex
as={IAIButton}
isChecked={isSelected}
sx={{
justifyContent: 'start',
p: 2,
borderRadius: 'base',
w: 'full',
alignItems: 'center',
bg: isSelected ? 'accent.400' : 'base.100',
color: isSelected ? 'base.50' : 'base.800',
_hover: {
bg: isSelected ? 'accent.500' : 'base.200',
color: isSelected ? 'base.50' : 'base.800',
},
_dark: {
color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.600' : 'base.850',
_hover: {
color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.550' : 'base.800',
},
},
}}
onClick={handleSelectModel}
>
<Tooltip label={model.description} hasArrow placement="bottom">
<Text sx={{ fontWeight: 500 }}>{model.model_name}</Text>
</Tooltip>
</Box>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<IAIIconButton
icon={<EditIcon />}
size="sm"
onClick={openModelHandler}
aria-label={t('accessibility.modifyConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
/>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
aria-label={t('modelManager.deleteConfig')}
isDisabled={isBusy}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
);
}

View File

@ -9,17 +9,15 @@ import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useLayoutEffect } from 'react';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import { ImageDTO } from 'services/api/types';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import {
CanvasInitialImageDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { memo, useLayoutEffect } from 'react';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
const selector = createSelector(
[canvasSelector, uiSelector],

View File

@ -0,0 +1,140 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineMultiSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -0,0 +1,134 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -1,23 +1,31 @@
import { MantineThemeOverride } from '@mantine/core';
import { useMemo } from 'react';
export const mantineTheme: MantineThemeOverride = {
colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
export const useMantineTheme = () => {
const mantineTheme: MantineThemeOverride = useMemo(
() => ({
colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
},
}),
[]
);
return mantineTheme;
};

View File

@ -4,7 +4,7 @@ import { paths } from '../schema';
type ListBoardImagesArg =
paths['/api/v1/board_images/{board_id}']['get']['parameters']['path'] &
paths['/api/v1/board_images/{board_id}']['get']['parameters']['query'];
paths['/api/v1/board_images/{board_id}']['get']['parameters']['query'];
type AddImageToBoardArg =
paths['/api/v1/board_images/']['post']['requestBody']['content']['application/json'];
@ -25,11 +25,12 @@ export const boardImagesApi = api.injectEndpoints({
query: ({ board_id, offset, limit }) => ({
url: `board_images/${board_id}`,
method: 'GET',
}),
providesTags: (result, error, arg) => {
// any list of boardimages
const tags: ApiFullTagDescription[] = [{ id: 'BoardImage', type: `${arg.board_id}_${LIST_TAG}` }];
const tags: ApiFullTagDescription[] = [
{ type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` },
];
if (result) {
// and individual tags for each boardimage
@ -57,7 +58,7 @@ export const boardImagesApi = api.injectEndpoints({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id }
{ type: 'Board', id: arg.board_id },
],
}),
@ -69,7 +70,7 @@ export const boardImagesApi = api.injectEndpoints({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id }
{ type: 'Board', id: arg.board_id },
],
}),
}),

View File

@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({
}),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
method: 'POST',
params: { board_name },
}),
invalidatesTags: [{ id: 'Board', type: LIST_TAG }],
invalidatesTags: [{ type: 'Board', id: LIST_TAG }],
}),
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
}),
deleteBoardAndImages: build.mutation<void, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }),
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }],
query: (board_id) => ({
url: `boards/${board_id}`,
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg },
{ type: 'Image', id: LIST_TAG },
],
}),
}),
});
@ -99,5 +106,5 @@ export const {
useCreateBoardMutation,
useUpdateBoardMutation,
useDeleteBoardMutation,
useDeleteBoardAndImagesMutation
useDeleteBoardAndImagesMutation,
} = boardsApi;

View File

@ -2,16 +2,27 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string };
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
@ -32,6 +43,38 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -76,7 +119,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG },
{ type: 'MainModel', id: LIST_TAG },
];
if (result) {
@ -104,11 +147,58 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG },
{ type: 'LoRAModel', id: LIST_TAG },
];
if (result) {
@ -143,7 +233,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG },
{ type: 'ControlNetModel', id: LIST_TAG },
];
if (result) {
@ -175,7 +265,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG },
{ type: 'VaeModel', id: LIST_TAG },
];
if (result) {
@ -210,7 +300,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG },
{ type: 'TextualInversionModel', id: LIST_TAG },
];
if (result) {
@ -247,4 +337,8 @@ export const {
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi;

View File

@ -3290,7 +3290,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
};
/**
* MultiplyInvocation
@ -4605,18 +4605,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -4997,7 +4997,7 @@ export type operations = {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description The model could not be found */
@ -5065,14 +5065,14 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Bad request */
@ -5106,7 +5106,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Bad request */
@ -5141,7 +5141,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Incompatible models */

View File

@ -42,17 +42,20 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs
export type Graph = components['schemas']['Graph'];

View File

@ -19,16 +19,8 @@ const invokeAI = defineStyle((props) => {
bg: mode('base.200', 'base.600')(props),
color: mode('base.850', 'base.100')(props),
borderRadius: 'base',
textShadow: mode(
'0 0 0.3rem var(--invokeai-colors-base-50)',
'0 0 0.3rem var(--invokeai-colors-base-900)'
)(props),
svg: {
fill: mode('base.850', 'base.100')(props),
filter: mode(
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-100))',
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))'
)(props),
},
_hover: {
bg: mode('base.300', 'base.500')(props),
@ -57,16 +49,8 @@ const invokeAI = defineStyle((props) => {
bg: mode(`${c}.400`, `${c}.600`)(props),
color: mode(`base.50`, `base.100`)(props),
borderRadius: 'base',
textShadow: mode(
`0 0 0.3rem var(--invokeai-colors-${c}-600)`,
`0 0 0.3rem var(--invokeai-colors-${c}-800)`
)(props),
svg: {
fill: mode(`base.50`, `base.100`)(props),
filter: mode(
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`,
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))`
)(props),
},
_disabled,
_hover: {

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props),
color: mode('base.500', 'base.400')(props),
}));
export const textTheme = defineStyleConfig({