Merge branch 'main' into lstein/model-manager-route-enhancements

This commit is contained in:
Lincoln Stein 2023-07-14 13:52:55 -04:00 committed by GitHub
commit e71ce83e9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 1273 additions and 983 deletions

View File

@ -5,6 +5,7 @@ from typing import Literal, Optional
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Union
from ..models.image import ImageCategory, ImageField, ResourceOrigin from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
@ -398,8 +399,8 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to resize") image: Optional[ImageField] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on # fmt: on

View File

@ -614,6 +614,7 @@ class ModelManager(object):
rmtree(str(model_path)) rmtree(str(model_path))
else: else:
model_path.unlink() model_path.unlink()
self.commit()
# LS: tested # LS: tested
def add_model( def add_model(

View File

@ -343,6 +343,7 @@
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelUpdated": "Model Updated", "modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted", "modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces", "cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New", "addNew": "Add New",
@ -397,8 +398,8 @@
"delete": "Delete", "delete": "Delete",
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location", "formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.", "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location", "formMessageDiffusersVAELocation": "VAE Location",
@ -409,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"v1": "v1", "v1": "v1",
@ -420,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"modelConverted": "Model Converted", "modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder", "invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location", "customSaveLocation": "Custom Save Location",
"merge": "Merge", "merge": "Merge",
"modelsMerged": "Models Merged", "modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"modelOne": "Model 1", "modelOne": "Model 1",
"modelTwo": "Model 2", "modelTwo": "Model 2",
@ -446,7 +449,8 @@
"weightedSum": "Weighted Sum", "weightedSum": "Weighted Sum",
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type" "pickModelType": "Pick Model Type",
"selectModel": "Select Model"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",
@ -599,7 +603,6 @@
"nodesLoaded": "Nodes Loaded", "nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes", "nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared" "nodesCleared": "Nodes Cleared"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {

View File

@ -9,9 +9,9 @@ import { theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter'; import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core'; import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css'; import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = { type ThemeLocaleProviderProps = {
children: ReactNode; children: ReactNode;
@ -35,8 +35,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction; document.body.dir = direction;
}, [direction]); }, [direction]);
const mantineTheme = useMantineTheme();
return ( return (
<MantineProvider withGlobalStyles theme={mantineTheme}> <MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}> <ChakraProvider theme={theme} colorModeManager={manager}>
{children} {children}
</ChakraProvider> </ChakraProvider>

View File

@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => {
// pass along the socket event as an application action // pass along the socket event as an application action
dispatch(appSocketConnected(action.payload)); dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
}, },
}); });
}; };

View File

@ -100,10 +100,11 @@ export const store = configureStore({
// manually type state, cannot type the arg // manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>; // const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) { // TODO: doing this breaks the rtk query devtools, commenting out for now
// don't log api actions, with manual cache updates they are extremely noisy // if (action.type.startsWith('api/')) {
return false; // // don't log api actions, with manual cache updates they are extremely noisy
} // return false;
// }
if (actionsDenylist.includes(action.type)) { if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions // don't log other noisy actions

View File

@ -1,11 +1,13 @@
import { Flex } from '@chakra-ui/react'; import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react'; import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({ export function IAIFormItemWrapper({
children, children,
}: { }: {
children: ReactElement | ReactElement[]; children: ReactElement | ReactElement[];
}) { }) {
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
@ -14,7 +16,7 @@ export function IAIFormItemWrapper({
rowGap: 4, rowGap: 4,
borderRadius: 'base', borderRadius: 'base',
width: 'full', width: 'full',
bg: 'base.900', bg: mode('base.100', 'base.900')(colorMode),
}} }}
> >
{children} {children}

View File

@ -0,0 +1,44 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
const { ...rest } = props;
const {
base50,
base100,
base200,
base300,
base800,
base700,
base900,
accent500,
accent300,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
})}
{...rest}
/>
);
}

View File

@ -1,10 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string; tooltip?: string;
@ -14,25 +13,6 @@ type IAIMultiSelectProps = MultiSelectProps & {
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode();
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => { (e: KeyboardEvent<HTMLInputElement>) => {
@ -52,6 +32,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
[dispatch] [dispatch]
); );
const styles = useMantineMultiSelectStyles();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
@ -60,92 +42,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
onKeyUp={handleKeyUp} onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}
maxDropdownHeight={300} maxDropdownHeight={300}
styles={() => ({ styles={styles}
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest} {...rest}
/> />
</Tooltip> </Tooltip>

View File

@ -0,0 +1,78 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = SelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={styles}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSearchableSelect);

View File

@ -1,10 +1,7 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { RefObject, memo } from 'react';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode';
export type IAISelectDataType = { export type IAISelectDataType = {
value: string; value: string;
@ -18,157 +15,13 @@ type IAISelectProps = SelectProps & {
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; const { tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode(); const styles = useMantineSelectStyles();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']);
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select ref={inputRef} styles={styles} {...rest} />
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest}
/>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -24,7 +24,7 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
canvasCopiedToClipboard, canvasCopiedToClipboard,
canvasDownloadedAsImage, canvasDownloadedAsImage,
@ -213,7 +213,7 @@ const IAICanvasToolbar = () => {
}} }}
> >
<Box w={24}> <Box w={24}>
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`} tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer} value={layer}
data={LAYER_NAMES_DICT} data={LAYER_NAMES_DICT}

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
@ -48,7 +48,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
data={controlNetModels} data={controlNetModels}
value={model} value={model}
onChange={handleModelChanged} onChange={handleModelChanged}

View File

@ -1,8 +1,12 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
@ -11,10 +15,6 @@ import {
ControlNetProcessorNode, ControlNetProcessorNode,
ControlNetProcessorType, ControlNetProcessorType,
} from '../../store/types'; } from '../../store/types';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
@ -72,7 +72,7 @@ const ParamControlNetProcessorSelect = (
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label="Processor" label="Processor"
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors} data={controlNetProcessors}

View File

@ -9,7 +9,7 @@ import {
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -106,7 +106,7 @@ const ParamEmbeddingPopover = (props: Props) => {
</Text> </Text>
</Flex> </Flex>
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
inputRef={inputRef} inputRef={inputRef}
autoFocus autoFocus
placeholder={'Add Embedding'} placeholder={'Add Embedding'}

View File

@ -12,10 +12,10 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useContext, useRef, useState } from 'react'; import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
const UpdateImageBoardModal = () => { const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll); // const boards = useSelector(selectBoardsAll);
@ -56,7 +56,7 @@ const UpdateImageBoardModal = () => {
{isFetching ? ( {isFetching ? (
<Spinner /> <Spinner />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
placeholder="Select Board" placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)} onChange={(v) => setSelectedBoard(v)}
value={selectedBoard} value={selectedBoard}

View File

@ -4,7 +4,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice'; import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
@ -84,7 +84,7 @@ const ParamLoRASelect = () => {
} }
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'} placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={null} value={null}
data={data} data={data}

View File

@ -3,7 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react'; import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
@ -77,7 +77,7 @@ const AddNodeMenu = () => {
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center' }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAIMantineSelect <IAIMantineSearchableSelect
selectOnBlur={false} selectOnBlur={false}
placeholder="Add Node" placeholder="Add Node"
value={null} value={null}

View File

@ -1,7 +1,7 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -89,7 +89,7 @@ const LoRAModelInputFieldComponent = (
} }
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
value={selectedLoRAModel?.id ?? null} value={selectedLoRAModel?.id ?? null}
label={ label={
selectedLoRAModel?.base_model && selectedLoRAModel?.base_model &&

View File

@ -6,7 +6,7 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -81,14 +81,14 @@ const ModelInputFieldComponent = (
); );
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('modelManager.model')} label={t('modelManager.model')}
placeholder="Loading..." placeholder="Loading..."
disabled={true} disabled={true}
data={[]} data={[]}
/> />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model] selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]

View File

@ -1,6 +1,6 @@
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -86,7 +86,7 @@ const VaeModelInputFieldComponent = (
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={ label={

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import { import {
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
}; };
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.scaleBeforeProcessing')} label={t('parameters.scaleBeforeProcessing')}
data={BOUNDING_BOX_SCALES_DICT} data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale} value={boundingBoxScale}

View File

@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
@ -48,7 +48,7 @@ const ParamScheduler = () => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
value={scheduler} value={scheduler}
data={data} data={data}

View File

@ -1,7 +1,7 @@
import { FACETOOL_TYPES } from 'app/constants'; import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
FacetoolType, FacetoolType,
setFacetoolType, setFacetoolType,
@ -20,7 +20,7 @@ export default function FaceRestoreType() {
dispatch(setFacetoolType(v as FacetoolType)); dispatch(setFacetoolType(v as FacetoolType));
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.type')} label={t('parameters.type')}
data={FACETOOL_TYPES.concat()} data={FACETOOL_TYPES.concat()}
value={facetoolType} value={facetoolType}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
@ -77,14 +77,14 @@ const ParamMainModelSelect = () => {
); );
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('modelManager.model')} label={t('modelManager.model')}
placeholder="Loading..." placeholder="Loading..."
disabled={true} disabled={true}
data={[]} data={[]}
/> />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.id} value={selectedModel?.id}

View File

@ -1,7 +1,7 @@
import { UPSCALING_LEVELS } from 'app/constants'; import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
UpscalingLevel, UpscalingLevel,
setUpscalingLevel, setUpscalingLevel,
@ -24,7 +24,7 @@ export default function UpscaleScale() {
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel)); dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
disabled={!isESRGANAvailable} disabled={!isESRGANAvailable}
label={t('parameters.scale')} label={t('parameters.scale')}
value={String(upscalingLevel)} value={String(upscalingLevel)}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -92,7 +92,7 @@ const ParamVAEModelSelect = () => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing', 'isProcessing',
'totalIterations', 'totalIterations',
'totalSteps', 'totalSteps',
'openModel',
'isCancelScheduled', 'isCancelScheduled',
'progressImage', 'progressImage',
'wereModelsReceived', 'wereModelsReceived',

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { reduce, pickBy } from 'lodash-es'; import { pickBy, reduce } from 'lodash-es';
export const systemSelector = (state: RootState) => state.system; export const systemSelector = (state: RootState) => state.system;
@ -50,3 +50,8 @@ export const languageSelector = createSelector(
export const isProcessingSelector = (state: RootState) => export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing; state.system.isProcessing;
export const selectIsBusy = createSelector(
(state: RootState) => state,
(state) => state.system.isProcessing || !state.system.isConnected
);

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[]; toastQueue: UseToastOptions[];
searchFolder: string | null; searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null; foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/** /**
* The current progress image * The current progress image
*/ */
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [], toastQueue: [],
searchFolder: null, searchFolder: null,
foundModels: null, foundModels: null,
openModel: null,
progressImage: null, progressImage: null,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
sessionId: null, sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => { ) => {
state.foundModels = action.payload; state.foundModels = action.payload;
}, },
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/** /**
* A cancel was scheduled * A cancel was scheduled
*/ */
@ -433,7 +428,6 @@ export const {
clearToastQueue, clearToastQueue,
setSearchFolder, setSearchFolder,
setFoundModels, setFoundModels,
setOpenModel,
cancelScheduled, cancelScheduled,
scheduledCancelAborted, scheduledCancelAborted,
cancelTypeChanged, cancelTypeChanged,

View File

@ -13,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode; content: ReactNode;
}; };
const modelManagerTabs: ModelManagerTabInfo[] = [ const tabs: ModelManagerTabInfo[] = [
{ {
id: 'modelManager', id: 'modelManager',
label: i18n.t('modelManager.modelManager'), label: i18n.t('modelManager.modelManager'),
@ -31,49 +31,28 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
}, },
]; ];
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
const ModelManagerTab = () => { const ModelManagerTab = () => {
return ( return (
<Tabs <Tabs
isLazy isLazy
variant="invokeAI" variant="line"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }} layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
> >
{renderTabsList()} <TabList>
{renderTabPanels()} {tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}
</TabPanel>
))}
</TabPanels>
</Tabs> </Tabs>
); );
}; };

View File

@ -1,4 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react'; import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -12,6 +12,8 @@ export default function AddModelsPanel() {
(state: RootState) => state.ui.addNewModelUIOption (state: RootState) => state.ui.addNewModelUIOption
); );
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -20,27 +22,13 @@ export default function AddModelsPanel() {
<Flex columnGap={4}> <Flex columnGap={4}>
<IAIButton <IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))} onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
sx={{ isChecked={addNewModelUIOption == 'ckpt'}
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
},
}}
> >
{t('modelManager.addCheckpointModel')} {t('modelManager.addCheckpointModel')}
</IAIButton> </IAIButton>
<IAIButton <IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))} onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
sx={{ isChecked={addNewModelUIOption == 'diffusers'}
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
},
}}
> >
{t('modelManager.addDiffuserModel')} {t('modelManager.addDiffuserModel')}
</IAIButton> </IAIButton>

View File

@ -66,13 +66,13 @@ export default function AddDiffusersModel() {
}; };
return ( return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}> <Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik <Formik
initialValues={addModelFormValues} initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler} onSubmit={addModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}> <VStack rowGap={2}>
<IAIFormItemWrapper> <IAIFormItemWrapper>
{/* Name */} {/* Name */}
@ -90,7 +90,6 @@ export default function AddDiffusersModel() {
name="name" name="name"
type="text" type="text"
validate={baseValidation} validate={baseValidation}
width="2xl"
isRequired isRequired
/> />
{!!errors.name && touched.name ? ( {!!errors.name && touched.name ? (
@ -119,7 +118,6 @@ export default function AddDiffusersModel() {
id="description" id="description"
name="description" name="description"
type="text" type="text"
width="2xl"
isRequired isRequired
/> />
{!!errors.description && touched.description ? ( {!!errors.description && touched.description ? (
@ -153,13 +151,7 @@ export default function AddDiffusersModel() {
{t('modelManager.modelLocation')} {t('modelManager.modelLocation')}
</FormLabel> </FormLabel>
<VStack alignItems="start"> <VStack alignItems="start">
<Field <Field as={IAIInput} id="path" name="path" type="text" />
as={IAIInput}
id="path"
name="path"
type="text"
width="2xl"
/>
{!!errors.path && touched.path ? ( {!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage> <FormErrorMessage>{errors.path}</FormErrorMessage>
) : ( ) : (
@ -181,7 +173,6 @@ export default function AddDiffusersModel() {
id="repo_id" id="repo_id"
name="repo_id" name="repo_id"
type="text" type="text"
width="2xl"
/> />
{!!errors.repo_id && touched.repo_id ? ( {!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage> <FormErrorMessage>{errors.repo_id}</FormErrorMessage>
@ -220,7 +211,6 @@ export default function AddDiffusersModel() {
id="vae.path" id="vae.path"
name="vae.path" name="vae.path"
type="text" type="text"
width="2xl"
/> />
{!!errors.vae?.path && touched.vae?.path ? ( {!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage> <FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
@ -245,7 +235,6 @@ export default function AddDiffusersModel() {
id="vae.repo_id" id="vae.repo_id"
name="vae.repo_id" name="vae.repo_id"
type="text" type="text"
width="2xl"
/> />
{!!errors.vae?.repo_id && touched.vae?.repo_id ? ( {!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage> <FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>

View File

@ -1,42 +1,85 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; import {
import { RootState } from 'app/store/store'; Flex,
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es'; import { pickBy } from 'lodash-es';
import { useState } from 'react'; import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
type MergeInterpolationMethods =
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference';
export default function MergeModelsPanel() { export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(); const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy( const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities, data?.entities,
(value, _) => value?.model_format === 'diffusers' (value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
); );
const [modelOne, setModelOne] = useState<string>( const sd2DiffusersModels = pickBy(
Object.keys(diffusersModels)[0] data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
); );
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1] const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
); );
const [modelThree, setModelThree] = useState<string>('none'); const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>(''); const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5); const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState< const [modelMergeInterp, setModelMergeInterp] =
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' useState<MergeInterpolationMethods>('weighted_sum');
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom' 'root' | 'custom'
@ -47,41 +90,73 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false); const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter( const modelOneList = Object.keys(
(model) => model !== modelTwo && model !== modelThree modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
); );
const modelTwoList = Object.keys(diffusersModels).filter( const handleBaseModelChange = (v: string) => {
(model) => model !== modelOne && model !== modelThree setBaseModel(v as BaseModelType);
); setModelOne(null);
setModelTwo(null);
const modelThreeList = [ };
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => { const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; const models_names: string[] = [];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
models_to_merge: modelsToMerge, modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name: merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha, alpha: modelMergeAlpha,
interp: modelMergeInterp, interp: modelMergeInterp,
model_merge_save_path: // model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, // modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce, force: modelMergeForce,
}; };
dispatch(mergeDiffusersModels(mergeModelsInfo)); mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
});
}; };
return ( return (
@ -90,7 +165,6 @@ export default function MergeModelsPanel() {
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
rowGap: 1, rowGap: 1,
bg: 'base.900',
}} }}
> >
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text> <Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +172,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')} {t('modelManager.modelMergeHeaderHelp2')}
</Text> </Text>
</Flex> </Flex>
<Flex columnGap={4}> <Flex columnGap={4}>
<IAISelect <IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelOne')} label={t('modelManager.modelOne')}
validValues={modelOneList} w="100%"
onChange={(e) => setModelOne(e.target.value)} value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/> />
<IAISelect <IAIMantineSearchableSelect
label={t('modelManager.modelTwo')} label={t('modelManager.modelTwo')}
validValues={modelTwoList} w="100%"
onChange={(e) => setModelTwo(e.target.value)} placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/> />
<IAISelect <IAIMantineSearchableSelect
label={t('modelManager.modelThree')} label={t('modelManager.modelThree')}
validValues={modelThreeList} data={modelThreeList}
onChange={(e) => { w="100%"
if (e.target.value !== 'none') { placeholder={t('modelManager.selectModel')}
setModelThree(e.target.value); clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference'); setModelMergeInterp('add_difference');
} else { } else {
setModelThree('none'); setModelThree(v);
setModelMergeInterp('weighted_sum'); setModelMergeInterp('weighted_sum');
} }
}} }}
@ -136,7 +227,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<IAISlider <IAISlider
@ -161,7 +252,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<Text fontWeight={500} fontSize="sm" variant="subtext"> <Text fontWeight={500} fontSize="sm" variant="subtext">
@ -169,12 +260,10 @@ export default function MergeModelsPanel() {
</Text> </Text>
<RadioGroup <RadioGroup
value={modelMergeInterp} value={modelMergeInterp}
onChange={( onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
) => setModelMergeInterp(v)}
> >
<Flex columnGap={4}> <Flex columnGap={4}>
{modelThree === 'none' ? ( {modelThree === null ? (
<> <>
<Radio value="weighted_sum"> <Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text> <Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +288,7 @@ export default function MergeModelsPanel() {
</RadioGroup> </RadioGroup>
</Flex> </Flex>
<Flex {/* <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
padding: 4, padding: 4,
@ -235,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)} onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/> />
)} )}
</Flex> </Flex> */}
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')} label={t('modelManager.ignoreMismatch')}
@ -246,10 +335,8 @@ export default function MergeModelsPanel() {
<IAIButton <IAIButton
onClick={mergeModelsHandler} onClick={mergeModelsHandler}
isLoading={isProcessing} isLoading={isLoading}
isDisabled={ isDisabled={modelOne === null || modelTwo === null}
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
> >
{t('modelManager.merge')} {t('modelManager.merge')}
</IAIButton> </IAIButton>

View File

@ -1,44 +1,60 @@
import { Flex } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return ( return (
<DiffusersModelEdit <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
modelToEdit={openModel} <ModelList
retrievedModel={mainModels['entities'][openModel]} selectedModelId={selectedModelId}
key={openModel} setSelectedModelId={setSelectedModelId}
/> />
); <ModelEdit model={model} />
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
</Flex> </Flex>
); );
} }
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
maxH: 96,
userSelect: 'none',
}}
>
<Text variant="subtext">No Model Selected</Text>
</Flex>
);
};

View File

@ -1,17 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const baseModelSelectData = [ const baseModelSelectData = [
@ -25,55 +28,92 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' }, { value: 'depth', label: 'Depth' },
]; ];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
modelToEdit: string; model: CheckpointModelConfigEntity;
retrievedModel: CheckpointModel;
}; };
export default function CheckpointModelEdit(props: CheckpointModelEditProps) { export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system.isProcessing
);
const { modelToEdit, retrievedModel } = props; const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const checkpointEditForm = useForm({ const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.model_name, model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
type: 'main', model_type: 'main',
path: retrievedModel.path, path: model.path ? model.path : '',
description: retrievedModel.description, description: model.description ? model.description : '',
model_format: 'checkpoint', model_format: 'checkpoint',
vae: retrievedModel.vae, vae: model.vae ? model.vae : '',
config: retrievedModel.config, config: model.config ? model.config : '',
variant: retrievedModel.variant, variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = useCallback(
console.log(values); (values: CheckpointModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
}; };
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<ModelConvert model={retrievedModel} /> <ModelConvert model={model} />
</Flex> </Flex>
<Divider /> <Divider />
@ -88,11 +128,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput <IAIMantineTextInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')} {...checkpointEditForm.getInputProps('description')}
/> />
@ -106,36 +142,28 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
data={variantSelectData} data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')} {...checkpointEditForm.getInputProps('variant')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')} {...checkpointEditForm.getInputProps('path')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.vaeLocation')} label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')} {...checkpointEditForm.getInputProps('vae')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.config')} label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')} {...checkpointEditForm.getInputProps('config')}
/> />
<IAIButton disabled={isProcessing} type="submit"> <IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</Flex> </Flex>
</form> </form>
</Flex> </Flex>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,25 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import type { RootState } from 'app/store/store'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
type DiffusersModel = import { useCallback } from 'react';
| S<'StableDiffusion1ModelDiffusersConfig'> import { useTranslation } from 'react-i18next';
| S<'StableDiffusion2ModelDiffusersConfig'>; import {
DiffusersModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
modelToEdit: string; model: DiffusersModelConfigEntity;
retrievedModel: DiffusersModel;
}; };
const baseModelSelectData = [ const baseModelSelectData = [
@ -34,39 +32,82 @@ const variantSelectData = [
]; ];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) { export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system.isProcessing
); const { model } = props;
const { retrievedModel, modelToEdit } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const diffusersEditForm = useForm({ const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.model_name, model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
type: 'main', model_type: 'main',
path: retrievedModel.path, path: model.path ? model.path : '',
description: retrievedModel.description, description: model.description ? model.description : '',
model_format: 'diffusers', model_format: 'diffusers',
vae: retrievedModel.vae, vae: model.vae ? model.vae : '',
variant: retrievedModel.variant, variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = useCallback(
console.log(values); (values: DiffusersModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
}; };
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<Divider /> <Divider />
@ -77,11 +118,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput <IAIMantineTextInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')} {...diffusersEditForm.getInputProps('description')}
/> />
@ -95,31 +132,23 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
data={variantSelectData} data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')} {...diffusersEditForm.getInputProps('variant')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')} {...diffusersEditForm.getInputProps('path')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.vaeLocation')} label={t('modelManager.vaeLocation')}
{...diffusersEditForm.getInputProps('vae')} {...diffusersEditForm.getInputProps('vae')}
/> />
<IAIButton disabled={isProcessing} type="submit"> <IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</Flex> </Flex>
</form> </form>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,23 +1,18 @@
import { import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions'; // import { convertToDiffusers } from 'app/socketio/actions';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps { interface ModelConvertProps {
model: CheckpointModel; model: CheckpointModelConfig;
} }
export default function ModelConvert(props: ModelConvertProps) { export default function ModelConvert(props: ModelConvertProps) {
@ -26,6 +21,8 @@ export default function ModelConvert(props: ModelConvertProps) {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -38,20 +35,39 @@ export default function ModelConvert(props: ModelConvertProps) {
}; };
const modelConvertHandler = () => { const modelConvertHandler = () => {
const modelToConvert = { const responseBody = {
model_name: model, base_model: model.base_model,
save_location: saveLocation, model_name: model.model_name,
custom_location:
saveLocation === 'custom' && customSaveLocation !== ''
? customSaveLocation
: null,
}; };
dispatch(convertToDiffusers(modelToConvert)); convertModel(responseBody)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${
model.model_name
}`,
status: 'error',
})
)
);
});
}; };
return ( return (
<IAIAlertDialog <IAIAlertDialog
title={`${t('modelManager.convert')} ${model.name}`} title={`${t('modelManager.convert')} ${model.model_name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler} cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}
@ -60,6 +76,7 @@ export default function ModelConvert(props: ModelConvertProps) {
size={'sm'} size={'sm'}
aria-label={t('modelManager.convertToDiffusers')} aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn" className=" modal-close-btn"
isLoading={isLoading}
> >
🧨 {t('modelManager.convertToDiffusers')} 🧨 {t('modelManager.convertToDiffusers')}
</IAIButton> </IAIButton>
@ -77,7 +94,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text> <Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex> </Flex>
<Flex flexDir="column" gap={4}> {/* <Flex flexDir="column" gap={4}>
<Flex marginTop={4} flexDir="column" gap={2}> <Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600"> <Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')} {t('modelManager.convertToDiffusersSaveLocation')}
@ -103,9 +120,9 @@ export default function ModelConvert(props: ModelConvertProps) {
</Radio> </Radio>
</Flex> </Flex>
</RadioGroup> </RadioGroup>
</Flex> </Flex> */}
{saveLocation === 'custom' && ( {/* {saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}> <Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext"> <Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')} {t('modelManager.customSaveLocation')}
@ -119,8 +136,7 @@ export default function ModelConvert(props: ModelConvertProps) {
width="full" width="full"
/> />
</Flex> </Flex>
)} )} */}
</Flex>
</IAIAlertDialog> </IAIAlertDialog>
); );
} }

View File

@ -1,185 +1,46 @@
import { Box, Flex, Spinner, Text } from '@chakra-ui/react'; import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next'; type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
import type { ChangeEvent, ReactNode } from 'react'; type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton
onClick={onClick}
isActive={isActive}
sx={{
_active: {
bg: 'accent.750',
},
}}
size="sm"
>
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => { const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
startTransition(() => { selectFromResult: ({ data }) => ({
setSearchText(e.target.value); filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}); }),
};
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!mainModels) return;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model]?.model_name
.toLowerCase()
.includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}
if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}); });
return searchText !== '' ? ( const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
isSelectedFilter === 'all' ? ( selectFromResult: ({ data }) => ({
<Box marginTop={4}>{filteredModelListItemsToRender}</Box> filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
) : ( }),
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box> });
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
</>
)}
{isSelectedFilter === 'diffusers' && ( const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
<Flex flexDirection="column" marginTop={4}> setNameFilter(e.target.value);
{diffusersModelListItemsToRender} }, []);
</Flex>
)}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -187,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter} onChange={handleSearchFilter}
label={t('modelManager.search')} label={t('modelManager.search')}
/> />
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={4} gap={4}
@ -195,34 +55,60 @@ const ModelList = () => {
overflow="scroll" overflow="scroll"
paddingInlineEnd={4} paddingInlineEnd={4}
> >
<Flex columnGap={2}> <ButtonGroup isAttached>
<ModelFilterButton <IAIButton
label={t('modelManager.allModels')} onClick={() => setModelFormatFilter('all')}
onClick={() => setIsSelectedFilter('all')} isChecked={modelFormatFilter === 'all'}
isActive={isSelectedFilter === 'all'} size="sm"
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
> >
<Spinner /> {t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
filteredCheckpointModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex> </Flex>
)} )}
</Flex> </Flex>
@ -231,3 +117,27 @@ const ModelList = () => {
}; };
export default ModelList; export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,78 +1,71 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon } from '@chakra-ui/icons';
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; import { Flex, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { setOpenModel } from 'features/system/store/systemSlice'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
modelKey: string; model: MainModelConfigEntity;
name: string; isSelected: boolean;
description: string | undefined; setSelectedModelId: (v: string | undefined) => void;
}; };
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
const { isProcessing, isConnected } = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const { t } = useTranslation(); const { t } = useTranslation();
const [deleteMainModel] = useDeleteMainModelsMutation();
const dispatch = useAppDispatch(); const { model, isSelected, setSelectedModelId } = props;
const { modelKey, name, description } = props; const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const openModelHandler = () => { const handleModelDelete = useCallback(() => {
dispatch(setOpenModel(modelKey)); deleteMainModel(model);
}; setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
const handleModelDelete = () => {
dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null));
};
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex <Flex
alignItems="center" as={IAIButton}
p={2} isChecked={isSelected}
borderRadius="base" sx={{
sx={ justifyContent: 'start',
modelKey === openModel p: 2,
? { borderRadius: 'base',
bg: 'accent.750', w: 'full',
alignItems: 'center',
bg: isSelected ? 'accent.400' : 'base.100',
color: isSelected ? 'base.50' : 'base.800',
_hover: { _hover: {
bg: 'accent.750', bg: isSelected ? 'accent.500' : 'base.200',
color: isSelected ? 'base.50' : 'base.800',
}, },
} _dark: {
: { color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.600' : 'base.850',
_hover: { _hover: {
bg: 'base.750', color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.550' : 'base.800',
}, },
} },
} }}
onClick={handleSelectModel}
> >
<Box onClick={openModelHandler} cursor="pointer"> <Tooltip label={model.description} hasArrow placement="bottom">
<Tooltip label={description} hasArrow placement="bottom"> <Text sx={{ fontWeight: 500 }}>{model.model_name}</Text>
<Text fontWeight="600">{name}</Text>
</Tooltip> </Tooltip>
</Box> </Flex>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<IAIIconButton
icon={<EditIcon />}
size="sm"
onClick={openModelHandler}
aria-label={t('accessibility.modifyConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
/>
<IAIAlertDialog <IAIAlertDialog
title={t('modelManager.deleteModel')} title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete} acceptCallback={handleModelDelete}
@ -80,9 +73,8 @@ export default function ModelListItem(props: ModelListItemProps) {
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
icon={<DeleteIcon />} icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')} aria-label={t('modelManager.deleteConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected} isDisabled={isBusy}
colorScheme="error" colorScheme="error"
/> />
} }
@ -93,6 +85,5 @@ export default function ModelListItem(props: ModelListItemProps) {
</Flex> </Flex>
</IAIAlertDialog> </IAIAlertDialog>
</Flex> </Flex>
</Flex>
); );
} }

View File

@ -9,17 +9,15 @@ import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useLayoutEffect } from 'react';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import { ImageDTO } from 'services/api/types';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { import {
CanvasInitialImageDropData, CanvasInitialImageDropData,
isValidDrop, isValidDrop,
useDroppable, useDroppable,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { memo, useLayoutEffect } from 'react';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
const selector = createSelector( const selector = createSelector(
[canvasSelector, uiSelector], [canvasSelector, uiSelector],

View File

@ -0,0 +1,140 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineMultiSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -0,0 +1,134 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -1,6 +1,9 @@
import { MantineThemeOverride } from '@mantine/core'; import { MantineThemeOverride } from '@mantine/core';
import { useMemo } from 'react';
export const mantineTheme: MantineThemeOverride = { export const useMantineTheme = () => {
const mantineTheme: MantineThemeOverride = useMemo(
() => ({
colorScheme: 'dark', colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`, fontFamily: `'Inter Variable', sans-serif`,
components: { components: {
@ -20,4 +23,9 @@ export const mantineTheme: MantineThemeOverride = {
}, },
}, },
}, },
}),
[]
);
return mantineTheme;
}; };

View File

@ -25,11 +25,12 @@ export const boardImagesApi = api.injectEndpoints({
query: ({ board_id, offset, limit }) => ({ query: ({ board_id, offset, limit }) => ({
url: `board_images/${board_id}`, url: `board_images/${board_id}`,
method: 'GET', method: 'GET',
}), }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boardimages // any list of boardimages
const tags: ApiFullTagDescription[] = [{ id: 'BoardImage', type: `${arg.board_id}_${LIST_TAG}` }]; const tags: ApiFullTagDescription[] = [
{ type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` },
];
if (result) { if (result) {
// and individual tags for each boardimage // and individual tags for each boardimage
@ -57,7 +58,7 @@ export const boardImagesApi = api.injectEndpoints({
}), }),
invalidatesTags: (result, error, arg) => [ invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' }, { type: 'BoardImage' },
{ type: 'Board', id: arg.board_id } { type: 'Board', id: arg.board_id },
], ],
}), }),
@ -69,7 +70,7 @@ export const boardImagesApi = api.injectEndpoints({
}), }),
invalidatesTags: (result, error, arg) => [ invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' }, { type: 'BoardImage' },
{ type: 'Board', id: arg.board_id } { type: 'Board', id: arg.board_id },
], ],
}), }),
}), }),

View File

@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({
query: (arg) => ({ url: 'boards/', params: arg }), query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boards // any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) { if (result) {
// and individual tags for each board // and individual tags for each board
@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({
}), }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boards // any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) { if (result) {
// and individual tags for each board // and individual tags for each board
@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
method: 'POST', method: 'POST',
params: { board_name }, params: { board_name },
}), }),
invalidatesTags: [{ id: 'Board', type: LIST_TAG }], invalidatesTags: [{ type: 'Board', id: LIST_TAG }],
}), }),
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({ updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }], invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
}), }),
deleteBoardAndImages: build.mutation<void, string>({ deleteBoardAndImages: build.mutation<void, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }), query: (board_id) => ({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }], url: `boards/${board_id}`,
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg },
{ type: 'Image', id: LIST_TAG },
],
}), }),
}), }),
}); });
@ -99,5 +106,5 @@ export const {
useCreateBoardMutation, useCreateBoardMutation,
useUpdateBoardMutation, useUpdateBoardMutation,
useDeleteBoardMutation, useDeleteBoardMutation,
useDeleteBoardAndImagesMutation useDeleteBoardAndImagesMutation,
} = boardsApi; } = boardsApi;

View File

@ -2,16 +2,27 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import { import {
AnyModelConfig, AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig, TextualInversionModelConfig,
VaeModelConfig, VaeModelConfig,
} from 'services/api/types'; } from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string }; export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
@ -32,6 +43,38 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity | TextualInversionModelConfigEntity
| VaeModelConfigEntity; | VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
@ -76,7 +119,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'main' } }), query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -104,11 +147,58 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG }, { type: 'LoRAModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -143,7 +233,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG }, { type: 'ControlNetModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -175,7 +265,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }), query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG }, { type: 'VaeModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -210,7 +300,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG }, { type: 'TextualInversionModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -247,4 +337,8 @@ export const {
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi; } = modelsApi;

View File

@ -3290,7 +3290,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4605,18 +4605,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -4997,7 +4997,7 @@ export type operations = {
/** @description The model imported successfully */ /** @description The model imported successfully */
201: { 201: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description The model could not be found */ /** @description The model could not be found */
@ -5065,14 +5065,14 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
responses: { responses: {
/** @description The model was updated successfully */ /** @description The model was updated successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5106,7 +5106,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5141,7 +5141,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Incompatible models */ /** @description Incompatible models */

View File

@ -42,17 +42,20 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig']; components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig = export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig']; components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig = export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion1ModelDiffusersConfig'] | components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']; | components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig = export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
| VaeModelConfig | VaeModelConfig
| ControlNetModelConfig | ControlNetModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig; | MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = components['schemas']['Graph'];

View File

@ -19,16 +19,8 @@ const invokeAI = defineStyle((props) => {
bg: mode('base.200', 'base.600')(props), bg: mode('base.200', 'base.600')(props),
color: mode('base.850', 'base.100')(props), color: mode('base.850', 'base.100')(props),
borderRadius: 'base', borderRadius: 'base',
textShadow: mode(
'0 0 0.3rem var(--invokeai-colors-base-50)',
'0 0 0.3rem var(--invokeai-colors-base-900)'
)(props),
svg: { svg: {
fill: mode('base.850', 'base.100')(props), fill: mode('base.850', 'base.100')(props),
filter: mode(
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-100))',
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))'
)(props),
}, },
_hover: { _hover: {
bg: mode('base.300', 'base.500')(props), bg: mode('base.300', 'base.500')(props),
@ -57,16 +49,8 @@ const invokeAI = defineStyle((props) => {
bg: mode(`${c}.400`, `${c}.600`)(props), bg: mode(`${c}.400`, `${c}.600`)(props),
color: mode(`base.50`, `base.100`)(props), color: mode(`base.50`, `base.100`)(props),
borderRadius: 'base', borderRadius: 'base',
textShadow: mode(
`0 0 0.3rem var(--invokeai-colors-${c}-600)`,
`0 0 0.3rem var(--invokeai-colors-${c}-800)`
)(props),
svg: { svg: {
fill: mode(`base.50`, `base.100`)(props), fill: mode(`base.50`, `base.100`)(props),
filter: mode(
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`,
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))`
)(props),
}, },
_disabled, _disabled,
_hover: { _hover: {

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools'; import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({ const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props), color: mode('base.500', 'base.400')(props),
})); }));
export const textTheme = defineStyleConfig({ export const textTheme = defineStyleConfig({