Merge branch 'main' into fix/nodes/fix-mouse-interactions

This commit is contained in:
blessedcoolant 2023-07-15 04:13:46 +12:00 committed by GitHub
commit 48561908b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 1270 additions and 981 deletions

View File

@ -589,6 +589,7 @@ class ModelManager(object):
rmtree(str(model_path)) rmtree(str(model_path))
else: else:
model_path.unlink() model_path.unlink()
self.commit()
# LS: tested # LS: tested
def add_model( def add_model(

View File

@ -343,6 +343,7 @@
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"modelAdded": "Model Added", "modelAdded": "Model Added",
"modelUpdated": "Model Updated", "modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted", "modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces", "cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New", "addNew": "Add New",
@ -397,8 +398,8 @@
"delete": "Delete", "delete": "Delete",
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location", "formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.", "formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location", "formMessageDiffusersVAELocation": "VAE Location",
@ -409,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.", "convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.", "convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.", "convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"v1": "v1", "v1": "v1",
@ -420,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config", "pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting", "statusConverting": "Converting",
"modelConverted": "Model Converted", "modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder", "invokeRoot": "InvokeAI folder",
"custom": "Custom", "custom": "Custom",
"customSaveLocation": "Custom Save Location", "customSaveLocation": "Custom Save Location",
"merge": "Merge", "merge": "Merge",
"modelsMerged": "Models Merged", "modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models", "mergeModels": "Merge Models",
"modelOne": "Model 1", "modelOne": "Model 1",
"modelTwo": "Model 2", "modelTwo": "Model 2",
@ -446,7 +449,8 @@
"weightedSum": "Weighted Sum", "weightedSum": "Weighted Sum",
"none": "none", "none": "none",
"addDifference": "Add Difference", "addDifference": "Add Difference",
"pickModelType": "Pick Model Type" "pickModelType": "Pick Model Type",
"selectModel": "Select Model"
}, },
"parameters": { "parameters": {
"general": "General", "general": "General",
@ -599,7 +603,6 @@
"nodesLoaded": "Nodes Loaded", "nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes", "nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared" "nodesCleared": "Nodes Cleared"
}, },
"tooltip": { "tooltip": {
"feature": { "feature": {

View File

@ -9,9 +9,9 @@ import { theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter'; import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core'; import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css'; import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = { type ThemeLocaleProviderProps = {
children: ReactNode; children: ReactNode;
@ -35,8 +35,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction; document.body.dir = direction;
}, [direction]); }, [direction]);
const mantineTheme = useMantineTheme();
return ( return (
<MantineProvider withGlobalStyles theme={mantineTheme}> <MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}> <ChakraProvider theme={theme} colorModeManager={manager}>
{children} {children}
</ChakraProvider> </ChakraProvider>

View File

@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions'; import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..'; import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' }); const moduleLog = log.child({ namespace: 'socketio' });
@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => {
// pass along the socket event as an application action // pass along the socket event as an application action
dispatch(appSocketConnected(action.payload)); dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
}, },
}); });
}; };

View File

@ -100,10 +100,11 @@ export const store = configureStore({
// manually type state, cannot type the arg // manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>; // const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) { // TODO: doing this breaks the rtk query devtools, commenting out for now
// don't log api actions, with manual cache updates they are extremely noisy // if (action.type.startsWith('api/')) {
return false; // // don't log api actions, with manual cache updates they are extremely noisy
} // return false;
// }
if (actionsDenylist.includes(action.type)) { if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions // don't log other noisy actions

View File

@ -1,11 +1,13 @@
import { Flex } from '@chakra-ui/react'; import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react'; import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({ export function IAIFormItemWrapper({
children, children,
}: { }: {
children: ReactElement | ReactElement[]; children: ReactElement | ReactElement[];
}) { }) {
const { colorMode } = useColorMode();
return ( return (
<Flex <Flex
sx={{ sx={{
@ -14,7 +16,7 @@ export function IAIFormItemWrapper({
rowGap: 4, rowGap: 4,
borderRadius: 'base', borderRadius: 'base',
width: 'full', width: 'full',
bg: 'base.900', bg: mode('base.100', 'base.900')(colorMode),
}} }}
> >
{children} {children}

View File

@ -0,0 +1,44 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
const { ...rest } = props;
const {
base50,
base100,
base200,
base300,
base800,
base700,
base900,
accent500,
accent300,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
})}
{...rest}
/>
);
}

View File

@ -1,10 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react'; import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & { type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string; tooltip?: string;
@ -14,25 +13,6 @@ type IAIMultiSelectProps = MultiSelectProps & {
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props; const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode();
const handleKeyDown = useCallback( const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => { (e: KeyboardEvent<HTMLInputElement>) => {
@ -52,6 +32,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
[dispatch] [dispatch]
); );
const styles = useMantineMultiSelectStyles();
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}> <Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect <MultiSelect
@ -60,92 +42,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
onKeyUp={handleKeyUp} onKeyUp={handleKeyUp}
searchable={searchable} searchable={searchable}
maxDropdownHeight={300} maxDropdownHeight={300}
styles={() => ({ styles={styles}
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest} {...rest}
/> />
</Tooltip> </Tooltip>

View File

@ -0,0 +1,78 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = SelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={styles}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSearchableSelect);

View File

@ -1,10 +1,7 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core'; import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { RefObject, memo } from 'react';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode';
export type IAISelectDataType = { export type IAISelectDataType = {
value: string; value: string;
@ -18,157 +15,13 @@ type IAISelectProps = SelectProps & {
}; };
const IAIMantineSelect = (props: IAISelectProps) => { const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props; const { tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode(); const styles = useMantineSelectStyles();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']);
return ( return (
<Tooltip label={tooltip} placement="top" hasArrow> <Tooltip label={tooltip} placement="top" hasArrow>
<Select <Select ref={inputRef} styles={styles} {...rest} />
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest}
/>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -24,7 +24,7 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors'; import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
canvasCopiedToClipboard, canvasCopiedToClipboard,
canvasDownloadedAsImage, canvasDownloadedAsImage,
@ -213,7 +213,7 @@ const IAICanvasToolbar = () => {
}} }}
> >
<Box w={24}> <Box w={24}>
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`} tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer} value={layer}
data={LAYER_NAMES_DICT} data={LAYER_NAMES_DICT}

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
@ -48,7 +48,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
data={controlNetModels} data={controlNetModels}
value={model} value={model}
onChange={handleModelChanged} onChange={handleModelChanged}

View File

@ -1,8 +1,12 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, { import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType, IAISelectDataType,
} from 'common/components/IAIMantineSelect'; } from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
@ -11,10 +15,6 @@ import {
ControlNetProcessorNode, ControlNetProcessorNode,
ControlNetProcessorType, ControlNetProcessorType,
} from '../../store/types'; } from '../../store/types';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
@ -72,7 +72,7 @@ const ParamControlNetProcessorSelect = (
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label="Processor" label="Processor"
value={processorNode.type ?? 'canny_image_processor'} value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors} data={controlNetProcessors}

View File

@ -9,7 +9,7 @@ import {
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -106,7 +106,7 @@ const ParamEmbeddingPopover = (props: Props) => {
</Text> </Text>
</Flex> </Flex>
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
inputRef={inputRef} inputRef={inputRef}
autoFocus autoFocus
placeholder={'Add Embedding'} placeholder={'Add Embedding'}

View File

@ -12,10 +12,10 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useContext, useRef, useState } from 'react'; import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
const UpdateImageBoardModal = () => { const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll); // const boards = useSelector(selectBoardsAll);
@ -56,7 +56,7 @@ const UpdateImageBoardModal = () => {
{isFetching ? ( {isFetching ? (
<Spinner /> <Spinner />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
placeholder="Select Board" placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)} onChange={(v) => setSelectedBoard(v)}
value={selectedBoard} value={selectedBoard}

View File

@ -4,7 +4,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store'; import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice'; import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
@ -84,7 +84,7 @@ const ParamLoRASelect = () => {
} }
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'} placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={null} value={null}
data={data} data={data}

View File

@ -3,7 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react'; import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css'; import 'reactflow/dist/style.css';
@ -77,7 +77,7 @@ const AddNodeMenu = () => {
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center' }}> <Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAIMantineSelect <IAIMantineSearchableSelect
selectOnBlur={false} selectOnBlur={false}
placeholder="Add Node" placeholder="Add Node"
value={null} value={null}

View File

@ -1,7 +1,7 @@
import { Flex, Text } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -89,7 +89,7 @@ const LoRAModelInputFieldComponent = (
} }
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
value={selectedLoRAModel?.id ?? null} value={selectedLoRAModel?.id ?? null}
label={ label={
selectedLoRAModel?.base_model && selectedLoRAModel?.base_model &&

View File

@ -6,7 +6,7 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -81,14 +81,14 @@ const ModelInputFieldComponent = (
); );
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('modelManager.model')} label={t('modelManager.model')}
placeholder="Loading..." placeholder="Loading..."
disabled={true} disabled={true}
data={[]} data={[]}
/> />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={ label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model] selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]

View File

@ -1,6 +1,6 @@
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -86,7 +86,7 @@ const VaeModelInputFieldComponent = (
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={ label={

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice'; import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import { import {
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
}; };
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.scaleBeforeProcessing')} label={t('parameters.scaleBeforeProcessing')}
data={BOUNDING_BOX_SCALES_DICT} data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale} value={boundingBoxScale}

View File

@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants'; import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice'; import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
@ -48,7 +48,7 @@ const ParamScheduler = () => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
value={scheduler} value={scheduler}
data={data} data={data}

View File

@ -1,7 +1,7 @@
import { FACETOOL_TYPES } from 'app/constants'; import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
FacetoolType, FacetoolType,
setFacetoolType, setFacetoolType,
@ -20,7 +20,7 @@ export default function FaceRestoreType() {
dispatch(setFacetoolType(v as FacetoolType)); dispatch(setFacetoolType(v as FacetoolType));
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('parameters.type')} label={t('parameters.type')}
data={FACETOOL_TYPES.concat()} data={FACETOOL_TYPES.concat()}
value={facetoolType} value={facetoolType}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
@ -77,14 +77,14 @@ const ParamMainModelSelect = () => {
); );
return isLoading ? ( return isLoading ? (
<IAIMantineSelect <IAIMantineSearchableSelect
label={t('modelManager.model')} label={t('modelManager.model')}
placeholder="Loading..." placeholder="Loading..."
disabled={true} disabled={true}
data={[]} data={[]}
/> />
) : ( ) : (
<IAIMantineSelect <IAIMantineSearchableSelect
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
label={t('modelManager.model')} label={t('modelManager.model')}
value={selectedModel?.id} value={selectedModel?.id}

View File

@ -1,7 +1,7 @@
import { UPSCALING_LEVELS } from 'app/constants'; import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store'; import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { import {
UpscalingLevel, UpscalingLevel,
setUpscalingLevel, setUpscalingLevel,
@ -24,7 +24,7 @@ export default function UpscaleScale() {
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel)); dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
disabled={!isESRGANAvailable} disabled={!isESRGANAvailable}
label={t('parameters.scale')} label={t('parameters.scale')}
value={String(upscalingLevel)} value={String(upscalingLevel)}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core'; import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
@ -92,7 +92,7 @@ const ParamVAEModelSelect = () => {
); );
return ( return (
<IAIMantineSelect <IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip} itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description} tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')} label={t('modelManager.vae')}

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing', 'isProcessing',
'totalIterations', 'totalIterations',
'totalSteps', 'totalSteps',
'openModel',
'isCancelScheduled', 'isCancelScheduled',
'progressImage', 'progressImage',
'wereModelsReceived', 'wereModelsReceived',

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { reduce, pickBy } from 'lodash-es'; import { pickBy, reduce } from 'lodash-es';
export const systemSelector = (state: RootState) => state.system; export const systemSelector = (state: RootState) => state.system;
@ -50,3 +50,8 @@ export const languageSelector = createSelector(
export const isProcessingSelector = (state: RootState) => export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing; state.system.isProcessing;
export const selectIsBusy = createSelector(
(state: RootState) => state,
(state) => state.system.isProcessing || !state.system.isConnected
);

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[]; toastQueue: UseToastOptions[];
searchFolder: string | null; searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null; foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/** /**
* The current progress image * The current progress image
*/ */
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [], toastQueue: [],
searchFolder: null, searchFolder: null,
foundModels: null, foundModels: null,
openModel: null,
progressImage: null, progressImage: null,
shouldAntialiasProgressImage: false, shouldAntialiasProgressImage: false,
sessionId: null, sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => { ) => {
state.foundModels = action.payload; state.foundModels = action.payload;
}, },
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/** /**
* A cancel was scheduled * A cancel was scheduled
*/ */
@ -433,7 +428,6 @@ export const {
clearToastQueue, clearToastQueue,
setSearchFolder, setSearchFolder,
setFoundModels, setFoundModels,
setOpenModel,
cancelScheduled, cancelScheduled,
scheduledCancelAborted, scheduledCancelAborted,
cancelTypeChanged, cancelTypeChanged,

View File

@ -13,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode; content: ReactNode;
}; };
const modelManagerTabs: ModelManagerTabInfo[] = [ const tabs: ModelManagerTabInfo[] = [
{ {
id: 'modelManager', id: 'modelManager',
label: i18n.t('modelManager.modelManager'), label: i18n.t('modelManager.modelManager'),
@ -31,49 +31,28 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
}, },
]; ];
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
const ModelManagerTab = () => { const ModelManagerTab = () => {
return ( return (
<Tabs <Tabs
isLazy isLazy
variant="invokeAI" variant="line"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }} layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
> >
{renderTabsList()} <TabList>
{renderTabPanels()} {tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}
</TabPanel>
))}
</TabPanels>
</Tabs> </Tabs>
); );
}; };

View File

@ -1,4 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react'; import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -12,6 +12,8 @@ export default function AddModelsPanel() {
(state: RootState) => state.ui.addNewModelUIOption (state: RootState) => state.ui.addNewModelUIOption
); );
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
@ -20,27 +22,13 @@ export default function AddModelsPanel() {
<Flex columnGap={4}> <Flex columnGap={4}>
<IAIButton <IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))} onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
sx={{ isChecked={addNewModelUIOption == 'ckpt'}
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
},
}}
> >
{t('modelManager.addCheckpointModel')} {t('modelManager.addCheckpointModel')}
</IAIButton> </IAIButton>
<IAIButton <IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))} onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
sx={{ isChecked={addNewModelUIOption == 'diffusers'}
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
},
}}
> >
{t('modelManager.addDiffuserModel')} {t('modelManager.addDiffuserModel')}
</IAIButton> </IAIButton>

View File

@ -66,13 +66,13 @@ export default function AddDiffusersModel() {
}; };
return ( return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}> <Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik <Formik
initialValues={addModelFormValues} initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler} onSubmit={addModelFormSubmitHandler}
> >
{({ handleSubmit, errors, touched }) => ( {({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}> <IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}> <VStack rowGap={2}>
<IAIFormItemWrapper> <IAIFormItemWrapper>
{/* Name */} {/* Name */}
@ -90,7 +90,6 @@ export default function AddDiffusersModel() {
name="name" name="name"
type="text" type="text"
validate={baseValidation} validate={baseValidation}
width="2xl"
isRequired isRequired
/> />
{!!errors.name && touched.name ? ( {!!errors.name && touched.name ? (
@ -119,7 +118,6 @@ export default function AddDiffusersModel() {
id="description" id="description"
name="description" name="description"
type="text" type="text"
width="2xl"
isRequired isRequired
/> />
{!!errors.description && touched.description ? ( {!!errors.description && touched.description ? (
@ -153,13 +151,7 @@ export default function AddDiffusersModel() {
{t('modelManager.modelLocation')} {t('modelManager.modelLocation')}
</FormLabel> </FormLabel>
<VStack alignItems="start"> <VStack alignItems="start">
<Field <Field as={IAIInput} id="path" name="path" type="text" />
as={IAIInput}
id="path"
name="path"
type="text"
width="2xl"
/>
{!!errors.path && touched.path ? ( {!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage> <FormErrorMessage>{errors.path}</FormErrorMessage>
) : ( ) : (
@ -181,7 +173,6 @@ export default function AddDiffusersModel() {
id="repo_id" id="repo_id"
name="repo_id" name="repo_id"
type="text" type="text"
width="2xl"
/> />
{!!errors.repo_id && touched.repo_id ? ( {!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage> <FormErrorMessage>{errors.repo_id}</FormErrorMessage>
@ -220,7 +211,6 @@ export default function AddDiffusersModel() {
id="vae.path" id="vae.path"
name="vae.path" name="vae.path"
type="text" type="text"
width="2xl"
/> />
{!!errors.vae?.path && touched.vae?.path ? ( {!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage> <FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
@ -245,7 +235,6 @@ export default function AddDiffusersModel() {
id="vae.repo_id" id="vae.repo_id"
name="vae.repo_id" name="vae.repo_id"
type="text" type="text"
width="2xl"
/> />
{!!errors.vae?.repo_id && touched.vae?.repo_id ? ( {!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage> <FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>

View File

@ -1,42 +1,85 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react'; import {
import { RootState } from 'app/store/store'; Flex,
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox'; import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider'; import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es'; import { pickBy } from 'lodash-es';
import { useState } from 'react'; import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
type MergeInterpolationMethods =
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference';
export default function MergeModelsPanel() { export default function MergeModelsPanel() {
const { t } = useTranslation(); const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery(); const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy( const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities, data?.entities,
(value, _) => value?.model_format === 'diffusers' (value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
); );
const [modelOne, setModelOne] = useState<string>( const sd2DiffusersModels = pickBy(
Object.keys(diffusersModels)[0] data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
); );
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1] const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
); );
const [modelThree, setModelThree] = useState<string>('none'); const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>(''); const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5); const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState< const [modelMergeInterp, setModelMergeInterp] =
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference' useState<MergeInterpolationMethods>('weighted_sum');
>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState< const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom' 'root' | 'custom'
@ -47,41 +90,73 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false); const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter( const modelOneList = Object.keys(
(model) => model !== modelTwo && model !== modelThree modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
); );
const modelTwoList = Object.keys(diffusersModels).filter( const handleBaseModelChange = (v: string) => {
(model) => model !== modelOne && model !== modelThree setBaseModel(v as BaseModelType);
); setModelOne(null);
setModelTwo(null);
const modelThreeList = [ };
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const mergeModelsHandler = () => { const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree]; const models_names: string[] = [];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = { let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
models_to_merge: modelsToMerge, modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name: merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'), mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha, alpha: modelMergeAlpha,
interp: modelMergeInterp, interp: modelMergeInterp,
model_merge_save_path: // model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc, // modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce, force: modelMergeForce,
}; };
dispatch(mergeDiffusersModels(mergeModelsInfo)); mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
});
}; };
return ( return (
@ -90,7 +165,6 @@ export default function MergeModelsPanel() {
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
rowGap: 1, rowGap: 1,
bg: 'base.900',
}} }}
> >
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text> <Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +172,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')} {t('modelManager.modelMergeHeaderHelp2')}
</Text> </Text>
</Flex> </Flex>
<Flex columnGap={4}> <Flex columnGap={4}>
<IAISelect <IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelOne')} label={t('modelManager.modelOne')}
validValues={modelOneList} w="100%"
onChange={(e) => setModelOne(e.target.value)} value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/> />
<IAISelect <IAIMantineSearchableSelect
label={t('modelManager.modelTwo')} label={t('modelManager.modelTwo')}
validValues={modelTwoList} w="100%"
onChange={(e) => setModelTwo(e.target.value)} placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/> />
<IAISelect <IAIMantineSearchableSelect
label={t('modelManager.modelThree')} label={t('modelManager.modelThree')}
validValues={modelThreeList} data={modelThreeList}
onChange={(e) => { w="100%"
if (e.target.value !== 'none') { placeholder={t('modelManager.selectModel')}
setModelThree(e.target.value); clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference'); setModelMergeInterp('add_difference');
} else { } else {
setModelThree('none'); setModelThree(v);
setModelMergeInterp('weighted_sum'); setModelMergeInterp('weighted_sum');
} }
}} }}
@ -136,7 +227,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<IAISlider <IAISlider
@ -161,7 +252,7 @@ export default function MergeModelsPanel() {
padding: 4, padding: 4,
borderRadius: 'base', borderRadius: 'base',
gap: 4, gap: 4,
bg: 'base.900', bg: mode('base.100', 'base.800')(colorMode),
}} }}
> >
<Text fontWeight={500} fontSize="sm" variant="subtext"> <Text fontWeight={500} fontSize="sm" variant="subtext">
@ -169,12 +260,10 @@ export default function MergeModelsPanel() {
</Text> </Text>
<RadioGroup <RadioGroup
value={modelMergeInterp} value={modelMergeInterp}
onChange={( onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
) => setModelMergeInterp(v)}
> >
<Flex columnGap={4}> <Flex columnGap={4}>
{modelThree === 'none' ? ( {modelThree === null ? (
<> <>
<Radio value="weighted_sum"> <Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text> <Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +288,7 @@ export default function MergeModelsPanel() {
</RadioGroup> </RadioGroup>
</Flex> </Flex>
<Flex {/* <Flex
sx={{ sx={{
flexDirection: 'column', flexDirection: 'column',
padding: 4, padding: 4,
@ -235,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)} onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/> />
)} )}
</Flex> </Flex> */}
<IAISimpleCheckbox <IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')} label={t('modelManager.ignoreMismatch')}
@ -246,10 +335,8 @@ export default function MergeModelsPanel() {
<IAIButton <IAIButton
onClick={mergeModelsHandler} onClick={mergeModelsHandler}
isLoading={isProcessing} isLoading={isLoading}
isDisabled={ isDisabled={modelOne === null || modelTwo === null}
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
> >
{t('modelManager.merge')} {t('modelManager.merge')}
</IAIButton> </IAIButton>

View File

@ -1,44 +1,60 @@
import { Flex } from '@chakra-ui/react'; import { Flex, Text } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList'; import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() { export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery(); const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return ( return (
<DiffusersModelEdit <Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
modelToEdit={openModel} <ModelList
retrievedModel={mainModels['entities'][openModel]} selectedModelId={selectedModelId}
key={openModel} setSelectedModelId={setSelectedModelId}
/> />
); <ModelEdit model={model} />
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
</Flex> </Flex>
); );
} }
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
maxH: 96,
userSelect: 'none',
}}
>
<Text variant="subtext">No Model Selected</Text>
</Flex>
);
};

View File

@ -1,17 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { RootState } from 'app/store/store';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert'; import ModelConvert from './ModelConvert';
const baseModelSelectData = [ const baseModelSelectData = [
@ -25,55 +28,92 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' }, { value: 'depth', label: 'Depth' },
]; ];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = { type CheckpointModelEditProps = {
modelToEdit: string; model: CheckpointModelConfigEntity;
retrievedModel: CheckpointModel;
}; };
export default function CheckpointModelEdit(props: CheckpointModelEditProps) { export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system.isProcessing
);
const { modelToEdit, retrievedModel } = props; const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const checkpointEditForm = useForm({ const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.model_name, model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
type: 'main', model_type: 'main',
path: retrievedModel.path, path: model.path ? model.path : '',
description: retrievedModel.description, description: model.description ? model.description : '',
model_format: 'checkpoint', model_format: 'checkpoint',
vae: retrievedModel.vae, vae: model.vae ? model.vae : '',
config: retrievedModel.config, config: model.config ? model.config : '',
variant: retrievedModel.variant, variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = useCallback(
console.log(values); (values: CheckpointModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
}; };
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<ModelConvert model={retrievedModel} /> <ModelConvert model={model} />
</Flex> </Flex>
<Divider /> <Divider />
@ -88,11 +128,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput <IAIMantineTextInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')} {...checkpointEditForm.getInputProps('description')}
/> />
@ -106,36 +142,28 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
data={variantSelectData} data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')} {...checkpointEditForm.getInputProps('variant')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')} {...checkpointEditForm.getInputProps('path')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.vaeLocation')} label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')} {...checkpointEditForm.getInputProps('vae')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.config')} label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')} {...checkpointEditForm.getInputProps('config')}
/> />
<IAIButton disabled={isProcessing} type="submit"> <IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</Flex> </Flex>
</form> </form>
</Flex> </Flex>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,25 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react'; import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form'; import { useForm } from '@mantine/form';
import type { RootState } from 'app/store/store'; import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
type DiffusersModel = import { useCallback } from 'react';
| S<'StableDiffusion1ModelDiffusersConfig'> import { useTranslation } from 'react-i18next';
| S<'StableDiffusion2ModelDiffusersConfig'>; import {
DiffusersModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = { type DiffusersModelEditProps = {
modelToEdit: string; model: DiffusersModelConfigEntity;
retrievedModel: DiffusersModel;
}; };
const baseModelSelectData = [ const baseModelSelectData = [
@ -34,39 +32,82 @@ const variantSelectData = [
]; ];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) { export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system.isProcessing
); const { model } = props;
const { retrievedModel, modelToEdit } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const diffusersEditForm = useForm({ const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: { initialValues: {
name: retrievedModel.model_name, model_name: model.model_name ? model.model_name : '',
base_model: retrievedModel.base_model, base_model: model.base_model,
type: 'main', model_type: 'main',
path: retrievedModel.path, path: model.path ? model.path : '',
description: retrievedModel.description, description: model.description ? model.description : '',
model_format: 'diffusers', model_format: 'diffusers',
vae: retrievedModel.vae, vae: model.vae ? model.vae : '',
variant: retrievedModel.variant, variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
}, },
}); });
const editModelFormSubmitHandler = (values) => { const editModelFormSubmitHandler = useCallback(
console.log(values); (values: DiffusersModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
}; };
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? ( return (
<Flex flexDirection="column" rowGap={4} width="100%"> <Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column"> <Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold"> <Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name} {model.model_name}
</Text> </Text>
<Text fontSize="sm" color="base.400"> <Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model {MODEL_TYPE_MAP[model.base_model]} Model
</Text> </Text>
</Flex> </Flex>
<Divider /> <Divider />
@ -77,11 +118,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)} )}
> >
<Flex flexDirection="column" overflowY="scroll" gap={4}> <Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput <IAIMantineTextInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
label={t('modelManager.description')} label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')} {...diffusersEditForm.getInputProps('description')}
/> />
@ -95,31 +132,23 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
data={variantSelectData} data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')} {...diffusersEditForm.getInputProps('variant')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.modelLocation')} label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')} {...diffusersEditForm.getInputProps('path')}
/> />
<IAIInput <IAIMantineTextInput
label={t('modelManager.vaeLocation')} label={t('modelManager.vaeLocation')}
{...diffusersEditForm.getInputProps('vae')} {...diffusersEditForm.getInputProps('vae')}
/> />
<IAIButton disabled={isProcessing} type="submit"> <IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')} {t('modelManager.updateModel')}
</IAIButton> </IAIButton>
</Flex> </Flex>
</form> </form>
</Flex> </Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
); );
} }

View File

@ -1,23 +1,18 @@
import { import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions'; // import { convertToDiffusers } from 'app/socketio/actions';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps { interface ModelConvertProps {
model: CheckpointModel; model: CheckpointModelConfig;
} }
export default function ModelConvert(props: ModelConvertProps) { export default function ModelConvert(props: ModelConvertProps) {
@ -26,6 +21,8 @@ export default function ModelConvert(props: ModelConvertProps) {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -38,20 +35,39 @@ export default function ModelConvert(props: ModelConvertProps) {
}; };
const modelConvertHandler = () => { const modelConvertHandler = () => {
const modelToConvert = { const responseBody = {
model_name: model, base_model: model.base_model,
save_location: saveLocation, model_name: model.model_name,
custom_location:
saveLocation === 'custom' && customSaveLocation !== ''
? customSaveLocation
: null,
}; };
dispatch(convertToDiffusers(modelToConvert)); convertModel(responseBody)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${
model.model_name
}`,
status: 'error',
})
)
);
});
}; };
return ( return (
<IAIAlertDialog <IAIAlertDialog
title={`${t('modelManager.convert')} ${model.name}`} title={`${t('modelManager.convert')} ${model.model_name}`}
acceptCallback={modelConvertHandler} acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler} cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`} acceptButtonText={`${t('modelManager.convert')}`}
@ -60,6 +76,7 @@ export default function ModelConvert(props: ModelConvertProps) {
size={'sm'} size={'sm'}
aria-label={t('modelManager.convertToDiffusers')} aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn" className=" modal-close-btn"
isLoading={isLoading}
> >
🧨 {t('modelManager.convertToDiffusers')} 🧨 {t('modelManager.convertToDiffusers')}
</IAIButton> </IAIButton>
@ -77,7 +94,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text> <Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex> </Flex>
<Flex flexDir="column" gap={4}> {/* <Flex flexDir="column" gap={4}>
<Flex marginTop={4} flexDir="column" gap={2}> <Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600"> <Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')} {t('modelManager.convertToDiffusersSaveLocation')}
@ -103,9 +120,9 @@ export default function ModelConvert(props: ModelConvertProps) {
</Radio> </Radio>
</Flex> </Flex>
</RadioGroup> </RadioGroup>
</Flex> </Flex> */}
{saveLocation === 'custom' && ( {/* {saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}> <Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext"> <Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')} {t('modelManager.customSaveLocation')}
@ -119,8 +136,7 @@ export default function ModelConvert(props: ModelConvertProps) {
width="full" width="full"
/> />
</Flex> </Flex>
)} )} */}
</Flex>
</IAIAlertDialog> </IAIAlertDialog>
); );
} }

View File

@ -1,185 +1,46 @@
import { Box, Flex, Spinner, Text } from '@chakra-ui/react'; import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput'; import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem'; import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next'; type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
import type { ChangeEvent, ReactNode } from 'react'; type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton
onClick={onClick}
isActive={isActive}
sx={{
_active: {
bg: 'accent.750',
},
}}
size="sm"
>
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation(); const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => { const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
startTransition(() => { selectFromResult: ({ data }) => ({
setSearchText(e.target.value); filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}); }),
};
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
if (!mainModels) return;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model]?.model_name
.toLowerCase()
.includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}
if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}); });
return searchText !== '' ? ( const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
isSelectedFilter === 'all' ? ( selectFromResult: ({ data }) => ({
<Box marginTop={4}>{filteredModelListItemsToRender}</Box> filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
) : ( }),
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box> });
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
</>
)}
{isSelectedFilter === 'diffusers' && ( const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
<Flex flexDirection="column" marginTop={4}> setNameFilter(e.target.value);
{diffusersModelListItemsToRender} }, []);
</Flex>
)}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter]);
return ( return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%"> <Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -187,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter} onChange={handleSearchFilter}
label={t('modelManager.search')} label={t('modelManager.search')}
/> />
<Flex <Flex
flexDirection="column" flexDirection="column"
gap={4} gap={4}
@ -195,34 +55,60 @@ const ModelList = () => {
overflow="scroll" overflow="scroll"
paddingInlineEnd={4} paddingInlineEnd={4}
> >
<Flex columnGap={2}> <ButtonGroup isAttached>
<ModelFilterButton <IAIButton
label={t('modelManager.allModels')} onClick={() => setModelFormatFilter('all')}
onClick={() => setIsSelectedFilter('all')} isChecked={modelFormatFilter === 'all'}
isActive={isSelectedFilter === 'all'} size="sm"
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
> >
<Spinner /> {t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
filteredCheckpointModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex> </Flex>
)} )}
</Flex> </Flex>
@ -231,3 +117,27 @@ const ModelList = () => {
}; };
export default ModelList; export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,78 +1,71 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons'; import { DeleteIcon } from '@chakra-ui/icons';
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react'; import { Flex, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { setOpenModel } from 'features/system/store/systemSlice'; import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = { type ModelListItemProps = {
modelKey: string; model: MainModelConfigEntity;
name: string; isSelected: boolean;
description: string | undefined; setSelectedModelId: (v: string | undefined) => void;
}; };
export default function ModelListItem(props: ModelListItemProps) { export default function ModelListItem(props: ModelListItemProps) {
const { isProcessing, isConnected } = useAppSelector( const isBusy = useAppSelector(selectIsBusy);
(state: RootState) => state.system
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const { t } = useTranslation(); const { t } = useTranslation();
const [deleteMainModel] = useDeleteMainModelsMutation();
const dispatch = useAppDispatch(); const { model, isSelected, setSelectedModelId } = props;
const { modelKey, name, description } = props; const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const openModelHandler = () => { const handleModelDelete = useCallback(() => {
dispatch(setOpenModel(modelKey)); deleteMainModel(model);
}; setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
const handleModelDelete = () => {
dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null));
};
return ( return (
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex <Flex
alignItems="center" as={IAIButton}
p={2} isChecked={isSelected}
borderRadius="base" sx={{
sx={ justifyContent: 'start',
modelKey === openModel p: 2,
? { borderRadius: 'base',
bg: 'accent.750', w: 'full',
alignItems: 'center',
bg: isSelected ? 'accent.400' : 'base.100',
color: isSelected ? 'base.50' : 'base.800',
_hover: { _hover: {
bg: 'accent.750', bg: isSelected ? 'accent.500' : 'base.200',
color: isSelected ? 'base.50' : 'base.800',
}, },
} _dark: {
: { color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.600' : 'base.850',
_hover: { _hover: {
bg: 'base.750', color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.550' : 'base.800',
}, },
} },
} }}
onClick={handleSelectModel}
> >
<Box onClick={openModelHandler} cursor="pointer"> <Tooltip label={model.description} hasArrow placement="bottom">
<Tooltip label={description} hasArrow placement="bottom"> <Text sx={{ fontWeight: 500 }}>{model.model_name}</Text>
<Text fontWeight="600">{name}</Text>
</Tooltip> </Tooltip>
</Box> </Flex>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<IAIIconButton
icon={<EditIcon />}
size="sm"
onClick={openModelHandler}
aria-label={t('accessibility.modifyConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
/>
<IAIAlertDialog <IAIAlertDialog
title={t('modelManager.deleteModel')} title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete} acceptCallback={handleModelDelete}
@ -80,9 +73,8 @@ export default function ModelListItem(props: ModelListItemProps) {
triggerComponent={ triggerComponent={
<IAIIconButton <IAIIconButton
icon={<DeleteIcon />} icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')} aria-label={t('modelManager.deleteConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected} isDisabled={isBusy}
colorScheme="error" colorScheme="error"
/> />
} }
@ -93,6 +85,5 @@ export default function ModelListItem(props: ModelListItemProps) {
</Flex> </Flex>
</IAIAlertDialog> </IAIAlertDialog>
</Flex> </Flex>
</Flex>
); );
} }

View File

@ -9,17 +9,15 @@ import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { uiSelector } from 'features/ui/store/uiSelectors'; import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useLayoutEffect } from 'react';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import { ImageDTO } from 'services/api/types';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { import {
CanvasInitialImageDropData, CanvasInitialImageDropData,
isValidDrop, isValidDrop,
useDroppable, useDroppable,
} from 'app/components/ImageDnd/typesafeDnd'; } from 'app/components/ImageDnd/typesafeDnd';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { memo, useLayoutEffect } from 'react';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
const selector = createSelector( const selector = createSelector(
[canvasSelector, uiSelector], [canvasSelector, uiSelector],

View File

@ -0,0 +1,140 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineMultiSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -0,0 +1,134 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -1,6 +1,9 @@
import { MantineThemeOverride } from '@mantine/core'; import { MantineThemeOverride } from '@mantine/core';
import { useMemo } from 'react';
export const mantineTheme: MantineThemeOverride = { export const useMantineTheme = () => {
const mantineTheme: MantineThemeOverride = useMemo(
() => ({
colorScheme: 'dark', colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`, fontFamily: `'Inter Variable', sans-serif`,
components: { components: {
@ -20,4 +23,9 @@ export const mantineTheme: MantineThemeOverride = {
}, },
}, },
}, },
}),
[]
);
return mantineTheme;
}; };

View File

@ -25,11 +25,12 @@ export const boardImagesApi = api.injectEndpoints({
query: ({ board_id, offset, limit }) => ({ query: ({ board_id, offset, limit }) => ({
url: `board_images/${board_id}`, url: `board_images/${board_id}`,
method: 'GET', method: 'GET',
}), }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boardimages // any list of boardimages
const tags: ApiFullTagDescription[] = [{ id: 'BoardImage', type: `${arg.board_id}_${LIST_TAG}` }]; const tags: ApiFullTagDescription[] = [
{ type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` },
];
if (result) { if (result) {
// and individual tags for each boardimage // and individual tags for each boardimage
@ -57,7 +58,7 @@ export const boardImagesApi = api.injectEndpoints({
}), }),
invalidatesTags: (result, error, arg) => [ invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' }, { type: 'BoardImage' },
{ type: 'Board', id: arg.board_id } { type: 'Board', id: arg.board_id },
], ],
}), }),
@ -69,7 +70,7 @@ export const boardImagesApi = api.injectEndpoints({
}), }),
invalidatesTags: (result, error, arg) => [ invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' }, { type: 'BoardImage' },
{ type: 'Board', id: arg.board_id } { type: 'Board', id: arg.board_id },
], ],
}), }),
}), }),

View File

@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({
query: (arg) => ({ url: 'boards/', params: arg }), query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boards // any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) { if (result) {
// and individual tags for each board // and individual tags for each board
@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({
}), }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
// any list of boards // any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }]; const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) { if (result) {
// and individual tags for each board // and individual tags for each board
@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
method: 'POST', method: 'POST',
params: { board_name }, params: { board_name },
}), }),
invalidatesTags: [{ id: 'Board', type: LIST_TAG }], invalidatesTags: [{ type: 'Board', id: LIST_TAG }],
}), }),
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({ updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }], invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
}), }),
deleteBoardAndImages: build.mutation<void, string>({ deleteBoardAndImages: build.mutation<void, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }), query: (board_id) => ({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }], url: `boards/${board_id}`,
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg },
{ type: 'Image', id: LIST_TAG },
],
}), }),
}), }),
}); });
@ -99,5 +106,5 @@ export const {
useCreateBoardMutation, useCreateBoardMutation,
useUpdateBoardMutation, useUpdateBoardMutation,
useDeleteBoardMutation, useDeleteBoardMutation,
useDeleteBoardAndImagesMutation useDeleteBoardAndImagesMutation,
} = boardsApi; } = boardsApi;

View File

@ -2,16 +2,27 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es'; import { cloneDeep } from 'lodash-es';
import { import {
AnyModelConfig, AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig, TextualInversionModelConfig,
VaeModelConfig, VaeModelConfig,
} from 'services/api/types'; } from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string }; export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
@ -32,6 +43,38 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity | TextualInversionModelConfigEntity
| VaeModelConfigEntity; | VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({ const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
@ -76,7 +119,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'main' } }), query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG }, { type: 'MainModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -104,11 +147,58 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({ getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }), query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG }, { type: 'LoRAModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -143,7 +233,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG }, { type: 'ControlNetModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -175,7 +265,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }), query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG }, { type: 'VaeModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -210,7 +300,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => { providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [ const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG }, { type: 'TextualInversionModel', id: LIST_TAG },
]; ];
if (result) { if (result) {
@ -247,4 +337,8 @@ export const {
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi; } = modelsApi;

View File

@ -3290,7 +3290,7 @@ export type components = {
/** ModelsList */ /** ModelsList */
ModelsList: { ModelsList: {
/** Models */ /** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[]; models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
}; };
/** /**
* MultiplyInvocation * MultiplyInvocation
@ -4605,18 +4605,18 @@ export type components = {
*/ */
image?: components["schemas"]["ImageField"]; image?: components["schemas"]["ImageField"];
}; };
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/** /**
* StableDiffusion2ModelFormat * StableDiffusion2ModelFormat
* @description An enumeration. * @description An enumeration.
* @enum {string} * @enum {string}
*/ */
StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
}; };
responses: never; responses: never;
parameters: never; parameters: never;
@ -4997,7 +4997,7 @@ export type operations = {
/** @description The model imported successfully */ /** @description The model imported successfully */
201: { 201: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description The model could not be found */ /** @description The model could not be found */
@ -5065,14 +5065,14 @@ export type operations = {
}; };
requestBody: { requestBody: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
responses: { responses: {
/** @description The model was updated successfully */ /** @description The model was updated successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5106,7 +5106,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Bad request */ /** @description Bad request */
@ -5141,7 +5141,7 @@ export type operations = {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {
content: { content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"]; "application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
}; };
}; };
/** @description Incompatible models */ /** @description Incompatible models */

View File

@ -42,17 +42,20 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig']; components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig = export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig']; components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig = export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion1ModelDiffusersConfig'] | components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']; | components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig = export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
| VaeModelConfig | VaeModelConfig
| ControlNetModelConfig | ControlNetModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig; | MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = components['schemas']['Graph'];

View File

@ -19,16 +19,8 @@ const invokeAI = defineStyle((props) => {
bg: mode('base.200', 'base.600')(props), bg: mode('base.200', 'base.600')(props),
color: mode('base.850', 'base.100')(props), color: mode('base.850', 'base.100')(props),
borderRadius: 'base', borderRadius: 'base',
textShadow: mode(
'0 0 0.3rem var(--invokeai-colors-base-50)',
'0 0 0.3rem var(--invokeai-colors-base-900)'
)(props),
svg: { svg: {
fill: mode('base.850', 'base.100')(props), fill: mode('base.850', 'base.100')(props),
filter: mode(
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-100))',
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))'
)(props),
}, },
_hover: { _hover: {
bg: mode('base.300', 'base.500')(props), bg: mode('base.300', 'base.500')(props),
@ -57,16 +49,8 @@ const invokeAI = defineStyle((props) => {
bg: mode(`${c}.400`, `${c}.600`)(props), bg: mode(`${c}.400`, `${c}.600`)(props),
color: mode(`base.50`, `base.100`)(props), color: mode(`base.50`, `base.100`)(props),
borderRadius: 'base', borderRadius: 'base',
textShadow: mode(
`0 0 0.3rem var(--invokeai-colors-${c}-600)`,
`0 0 0.3rem var(--invokeai-colors-${c}-800)`
)(props),
svg: { svg: {
fill: mode(`base.50`, `base.100`)(props), fill: mode(`base.50`, `base.100`)(props),
filter: mode(
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`,
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))`
)(props),
}, },
_disabled, _disabled,
_hover: { _hover: {

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools'; import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({ const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props), color: mode('base.500', 'base.400')(props),
})); }));
export const textTheme = defineStyleConfig({ export const textTheme = defineStyleConfig({