Merge branch 'main' into fix/nodes/fix-mouse-interactions

This commit is contained in:
blessedcoolant 2023-07-15 04:13:46 +12:00 committed by GitHub
commit 48561908b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 1270 additions and 981 deletions

View File

@ -589,6 +589,7 @@ class ModelManager(object):
rmtree(str(model_path))
else:
model_path.unlink()
self.commit()
# LS: tested
def add_model(

View File

@ -343,6 +343,7 @@
"safetensorModels": "SafeTensors",
"modelAdded": "Model Added",
"modelUpdated": "Model Updated",
"modelUpdateFailed": "Model Update Failed",
"modelEntryDeleted": "Model Entry Deleted",
"cannotUseSpaces": "Cannot Use Spaces",
"addNew": "Add New",
@ -397,8 +398,8 @@
"delete": "Delete",
"deleteModel": "Delete Model",
"deleteConfig": "Delete Config",
"deleteMsg1": "Are you sure you want to delete this model entry from InvokeAI?",
"deleteMsg2": "This will not delete the model checkpoint file from your disk. You can readd them if you wish to.",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"formMessageDiffusersModelLocation": "Diffusers Model Location",
"formMessageDiffusersModelLocationDesc": "Please enter at least one.",
"formMessageDiffusersVAELocation": "VAE Location",
@ -409,7 +410,7 @@
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on the disk will NOT be deleted or modified in anyway. You can add your checkpoint to the Model Manager again if you want to.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 4GB-7GB in size.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"v1": "v1",
@ -420,12 +421,14 @@
"pathToCustomConfig": "Path To Custom Config",
"statusConverting": "Converting",
"modelConverted": "Model Converted",
"modelConversionFailed": "Model Conversion Failed",
"sameFolder": "Same folder",
"invokeRoot": "InvokeAI folder",
"custom": "Custom",
"customSaveLocation": "Custom Save Location",
"merge": "Merge",
"modelsMerged": "Models Merged",
"modelsMergeFailed": "Model Merge Failed",
"mergeModels": "Merge Models",
"modelOne": "Model 1",
"modelTwo": "Model 2",
@ -446,7 +449,8 @@
"weightedSum": "Weighted Sum",
"none": "none",
"addDifference": "Add Difference",
"pickModelType": "Pick Model Type"
"pickModelType": "Pick Model Type",
"selectModel": "Select Model"
},
"parameters": {
"general": "General",
@ -599,7 +603,6 @@
"nodesLoaded": "Nodes Loaded",
"nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared"
},
"tooltip": {
"feature": {

View File

@ -9,9 +9,9 @@ import { theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
import { mantineTheme } from 'mantine-theme/theme';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
import { useMantineTheme } from 'mantine-theme/theme';
type ThemeLocaleProviderProps = {
children: ReactNode;
@ -35,8 +35,10 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction;
}, [direction]);
const mantineTheme = useMantineTheme();
return (
<MantineProvider withGlobalStyles theme={mantineTheme}>
<MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>

View File

@ -1,6 +1,7 @@
import { log } from 'app/logging/useLogger';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { modelsApi } from 'services/api/endpoints/models';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
const moduleLog = log.child({ namespace: 'socketio' });
@ -23,6 +24,13 @@ export const addSocketConnectedEventListener = () => {
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
// update all server state
dispatch(modelsApi.endpoints.getMainModels.initiate());
dispatch(modelsApi.endpoints.getControlNetModels.initiate());
dispatch(modelsApi.endpoints.getLoRAModels.initiate());
dispatch(modelsApi.endpoints.getTextualInversionModels.initiate());
dispatch(modelsApi.endpoints.getVaeModels.initiate());
},
});
};

View File

@ -100,10 +100,11 @@ export const store = configureStore({
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
if (action.type.startsWith('api/')) {
// don't log api actions, with manual cache updates they are extremely noisy
return false;
}
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions

View File

@ -1,11 +1,13 @@
import { Flex } from '@chakra-ui/react';
import { Flex, useColorMode } from '@chakra-ui/react';
import { ReactElement } from 'react';
import { mode } from 'theme/util/mode';
export function IAIFormItemWrapper({
children,
}: {
children: ReactElement | ReactElement[];
}) {
const { colorMode } = useColorMode();
return (
<Flex
sx={{
@ -14,7 +16,7 @@ export function IAIFormItemWrapper({
rowGap: 4,
borderRadius: 'base',
width: 'full',
bg: 'base.900',
bg: mode('base.100', 'base.900')(colorMode),
}}
>
{children}

View File

@ -0,0 +1,44 @@
import { useColorMode } from '@chakra-ui/react';
import { TextInput, TextInputProps } from '@mantine/core';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { mode } from 'theme/util/mode';
type IAIMantineTextInputProps = TextInputProps;
export default function IAIMantineTextInput(props: IAIMantineTextInputProps) {
const { ...rest } = props;
const {
base50,
base100,
base200,
base300,
base800,
base700,
base900,
accent500,
accent300,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
return (
<TextInput
styles={() => ({
input: {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base50, base900)(colorMode),
borderColor: mode(base200, base800)(colorMode),
borderWidth: 2,
outline: 'none',
':focus': {
borderColor: mode(accent300, accent500)(colorMode),
},
},
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
})}
{...rest}
/>
);
}

View File

@ -1,10 +1,9 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { Tooltip } from '@chakra-ui/react';
import { MultiSelect, MultiSelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineMultiSelectStyles } from 'mantine-theme/hooks/useMantineMultiSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback } from 'react';
import { mode } from 'theme/util/mode';
type IAIMultiSelectProps = MultiSelectProps & {
tooltip?: string;
@ -14,25 +13,6 @@ type IAIMultiSelectProps = MultiSelectProps & {
const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
const { searchable = true, tooltip, inputRef, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const { colorMode } = useColorMode();
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
@ -52,6 +32,8 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
[dispatch]
);
const styles = useMantineMultiSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow isOpen={true}>
<MultiSelect
@ -60,92 +42,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
styles={styles}
{...rest}
/>
</Tooltip>

View File

@ -0,0 +1,78 @@
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
export type IAISelectDataType = {
value: string;
label: string;
tooltip?: string;
};
type IAISelectProps = SelectProps & {
tooltip?: string;
inputRef?: RefObject<HTMLInputElement>;
};
const IAIMantineSearchableSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={styles}
{...rest}
/>
</Tooltip>
);
};
export default memo(IAIMantineSearchableSelect);

View File

@ -1,10 +1,7 @@
import { Tooltip, useColorMode, useToken } from '@chakra-ui/react';
import { Tooltip } from '@chakra-ui/react';
import { Select, SelectProps } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import { KeyboardEvent, RefObject, memo, useCallback, useState } from 'react';
import { mode } from 'theme/util/mode';
import { useMantineSelectStyles } from 'mantine-theme/hooks/useMantineSelectStyles';
import { RefObject, memo } from 'react';
export type IAISelectDataType = {
value: string;
@ -18,157 +15,13 @@ type IAISelectProps = SelectProps & {
};
const IAIMantineSelect = (props: IAISelectProps) => {
const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
const dispatch = useAppDispatch();
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { tooltip, inputRef, ...rest } = props;
const { colorMode } = useColorMode();
const [searchValue, setSearchValue] = useState('');
// we want to capture shift keypressed even when an input is focused
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (e.shiftKey) {
dispatch(shiftKeyPressed(true));
}
},
[dispatch]
);
const handleKeyUp = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
if (!e.shiftKey) {
dispatch(shiftKeyPressed(false));
}
},
[dispatch]
);
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
setSearchValue('');
if (!onChange) {
return;
}
onChange(v);
},
[onChange]
);
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useMantineSelectStyles();
return (
<Tooltip label={tooltip} placement="top" hasArrow>
<Select
ref={inputRef}
searchValue={searchValue}
onSearchChange={setSearchValue}
onChange={handleChange}
onKeyDown={handleKeyDown}
onKeyUp={handleKeyUp}
searchable={searchable}
maxDropdownHeight={300}
styles={() => ({
label: {
color: mode(base700, base300)(colorMode),
fontWeight: 'normal',
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
})}
{...rest}
/>
<Select ref={inputRef} styles={styles} {...rest} />
</Tooltip>
);
};

View File

@ -24,7 +24,7 @@ import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash-es';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
canvasCopiedToClipboard,
canvasDownloadedAsImage,
@ -213,7 +213,7 @@ const IAICanvasToolbar = () => {
}}
>
<Box w={24}>
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={`${t('unifiedCanvas.layer')} (Q)`}
value={layer}
data={LAYER_NAMES_DICT}

View File

@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import {
CONTROLNET_MODELS,
@ -48,7 +48,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
data={controlNetModels}
value={model}
onChange={handleModelChanged}

View File

@ -1,8 +1,12 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect, {
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSearchableSelect, {
IAISelectDataType,
} from 'common/components/IAIMantineSelect';
} from 'common/components/IAIMantineSearchableSelect';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { configSelector } from 'features/system/store/configSelectors';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
@ -11,10 +15,6 @@ import {
ControlNetProcessorNode,
ControlNetProcessorType,
} from '../../store/types';
import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke';
import { createSelector } from '@reduxjs/toolkit';
import { configSelector } from 'features/system/store/configSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
@ -72,7 +72,7 @@ const ParamControlNetProcessorSelect = (
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label="Processor"
value={processorNode.type ?? 'canny_image_processor'}
data={controlNetProcessors}

View File

@ -9,7 +9,7 @@ import {
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
@ -106,7 +106,7 @@ const ParamEmbeddingPopover = (props: Props) => {
</Text>
</Flex>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
inputRef={inputRef}
autoFocus
placeholder={'Add Embedding'}

View File

@ -12,10 +12,10 @@ import {
} from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useContext, useRef, useState } from 'react';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { AddImageToBoardContext } from '../../../../app/contexts/AddImageToBoardContext';
const UpdateImageBoardModal = () => {
// const boards = useSelector(selectBoardsAll);
@ -56,7 +56,7 @@ const UpdateImageBoardModal = () => {
{isFetching ? (
<Spinner />
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
placeholder="Select Board"
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}

View File

@ -4,7 +4,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState, stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
@ -84,7 +84,7 @@ const ParamLoRASelect = () => {
}
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
placeholder={data.length === 0 ? 'All LoRAs added' : 'Add LoRA'}
value={null}
data={data}

View File

@ -3,7 +3,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { map } from 'lodash-es';
import { forwardRef, useCallback } from 'react';
import 'reactflow/dist/style.css';
@ -77,7 +77,7 @@ const AddNodeMenu = () => {
return (
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAIMantineSelect
<IAIMantineSearchableSelect
selectOnBlur={false}
placeholder="Add Node"
value={null}

View File

@ -1,7 +1,7 @@
import { Flex, Text } from '@chakra-ui/react';
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
@ -89,7 +89,7 @@ const LoRAModelInputFieldComponent = (
}
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
value={selectedLoRAModel?.id ?? null}
label={
selectedLoRAModel?.base_model &&

View File

@ -6,7 +6,7 @@ import {
} from 'features/nodes/types/types';
import { SelectItem } from '@mantine/core';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam';
import { forEach } from 'lodash-es';
@ -81,14 +81,14 @@ const ModelInputFieldComponent = (
);
return isLoading ? (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={
selectedModel?.base_model && MODEL_TYPE_MAP[selectedModel?.base_model]

View File

@ -1,6 +1,6 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
@ -86,7 +86,7 @@ const VaeModelInputFieldComponent = (
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { setBoundingBoxScaleMethod } from 'features/canvas/store/canvasSlice';
import {
@ -35,7 +35,7 @@ const ParamScaleBeforeProcessing = () => {
};
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scaleBeforeProcessing')}
data={BOUNDING_BOX_SCALES_DICT}
value={boundingBoxScale}

View File

@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { SCHEDULER_LABEL_MAP, SCHEDULER_NAMES } from 'app/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { setScheduler } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
@ -48,7 +48,7 @@ const ParamScheduler = () => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.scheduler')}
value={scheduler}
data={data}

View File

@ -1,7 +1,7 @@
import { FACETOOL_TYPES } from 'app/constants';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
FacetoolType,
setFacetoolType,
@ -20,7 +20,7 @@ export default function FaceRestoreType() {
dispatch(setFacetoolType(v as FacetoolType));
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('parameters.type')}
data={FACETOOL_TYPES.concat()}
value={facetoolType}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { createSelector } from '@reduxjs/toolkit';
@ -77,14 +77,14 @@ const ParamMainModelSelect = () => {
);
return isLoading ? (
<IAIMantineSelect
<IAIMantineSearchableSelect
label={t('modelManager.model')}
placeholder="Loading..."
disabled={true}
data={[]}
/>
) : (
<IAIMantineSelect
<IAIMantineSearchableSelect
tooltip={selectedModel?.description}
label={t('modelManager.model')}
value={selectedModel?.id}

View File

@ -1,7 +1,7 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import {
UpscalingLevel,
setUpscalingLevel,
@ -24,7 +24,7 @@ export default function UpscaleScale() {
dispatch(setUpscalingLevel(Number(v) as UpscalingLevel));
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
disabled={!isESRGANAvailable}
label={t('parameters.scale')}
value={String(upscalingLevel)}

View File

@ -2,7 +2,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { SelectItem } from '@mantine/core';
import { forEach } from 'lodash-es';
@ -92,7 +92,7 @@ const ParamVAEModelSelect = () => {
);
return (
<IAIMantineSelect
<IAIMantineSearchableSelect
itemComponent={IAIMantineSelectItemWithTooltip}
tooltip={selectedVaeModel?.description}
label={t('modelManager.vae')}

View File

@ -13,7 +13,6 @@ export const systemPersistDenylist: (keyof SystemState)[] = [
'isProcessing',
'totalIterations',
'totalSteps',
'openModel',
'isCancelScheduled',
'progressImage',
'wereModelsReceived',

View File

@ -1,7 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { reduce, pickBy } from 'lodash-es';
import { pickBy, reduce } from 'lodash-es';
export const systemSelector = (state: RootState) => state.system;
@ -50,3 +50,8 @@ export const languageSelector = createSelector(
export const isProcessingSelector = (state: RootState) =>
state.system.isProcessing;
export const selectIsBusy = createSelector(
(state: RootState) => state,
(state) => state.system.isProcessing || !state.system.isConnected
);

View File

@ -46,7 +46,6 @@ export interface SystemState {
toastQueue: UseToastOptions[];
searchFolder: string | null;
foundModels: InvokeAI.FoundModel[] | null;
openModel: string | null;
/**
* The current progress image
*/
@ -109,7 +108,6 @@ export const initialSystemState: SystemState = {
toastQueue: [],
searchFolder: null,
foundModels: null,
openModel: null,
progressImage: null,
shouldAntialiasProgressImage: false,
sessionId: null,
@ -164,9 +162,6 @@ export const systemSlice = createSlice({
) => {
state.foundModels = action.payload;
},
setOpenModel: (state, action: PayloadAction<string | null>) => {
state.openModel = action.payload;
},
/**
* A cancel was scheduled
*/
@ -433,7 +428,6 @@ export const {
clearToastQueue,
setSearchFolder,
setFoundModels,
setOpenModel,
cancelScheduled,
scheduledCancelAborted,
cancelTypeChanged,

View File

@ -13,7 +13,7 @@ type ModelManagerTabInfo = {
content: ReactNode;
};
const modelManagerTabs: ModelManagerTabInfo[] = [
const tabs: ModelManagerTabInfo[] = [
{
id: 'modelManager',
label: i18n.t('modelManager.modelManager'),
@ -31,49 +31,28 @@ const modelManagerTabs: ModelManagerTabInfo[] = [
},
];
const renderTabsList = () => {
const modelManagerTabListsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabListsToRender.push(
<Tab key={modelManagerTab.id}>{modelManagerTab.label}</Tab>
);
});
return (
<TabList
sx={{
w: '100%',
color: 'base.200',
flexDirection: 'row',
borderBottomWidth: 2,
borderColor: 'accent.700',
}}
>
{modelManagerTabListsToRender}
</TabList>
);
};
const renderTabPanels = () => {
const modelManagerTabPanelsToRender: ReactNode[] = [];
modelManagerTabs.forEach((modelManagerTab) => {
modelManagerTabPanelsToRender.push(
<TabPanel key={modelManagerTab.id}>{modelManagerTab.content}</TabPanel>
);
});
return <TabPanels sx={{ p: 2 }}>{modelManagerTabPanelsToRender}</TabPanels>;
};
const ModelManagerTab = () => {
return (
<Tabs
isLazy
variant="invokeAI"
sx={{ w: 'full', h: 'full', p: 2, gap: 4, flexDirection: 'column' }}
variant="line"
layerStyle="first"
sx={{ w: 'full', h: 'full', p: 4, gap: 4, borderRadius: 'base' }}
>
{renderTabsList()}
{renderTabPanels()}
<TabList>
{tabs.map((tab) => (
<Tab sx={{ borderTopRadius: 'base' }} key={tab.id}>
{tab.label}
</Tab>
))}
</TabList>
<TabPanels sx={{ w: 'full', h: 'full', p: 4 }}>
{tabs.map((tab) => (
<TabPanel sx={{ w: 'full', h: 'full' }} key={tab.id}>
{tab.content}
</TabPanel>
))}
</TabPanels>
</Tabs>
);
};

View File

@ -1,4 +1,4 @@
import { Divider, Flex } from '@chakra-ui/react';
import { Divider, Flex, useColorMode } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -12,6 +12,8 @@ export default function AddModelsPanel() {
(state: RootState) => state.ui.addNewModelUIOption
);
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -20,27 +22,13 @@ export default function AddModelsPanel() {
<Flex columnGap={4}>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('ckpt'))}
sx={{
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'ckpt' ? 'accent.700' : 'base.600',
},
}}
isChecked={addNewModelUIOption == 'ckpt'}
>
{t('modelManager.addCheckpointModel')}
</IAIButton>
<IAIButton
onClick={() => dispatch(setAddNewModelUIOption('diffusers'))}
sx={{
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.700',
'&:hover': {
backgroundColor:
addNewModelUIOption == 'diffusers' ? 'accent.700' : 'base.600',
},
}}
isChecked={addNewModelUIOption == 'diffusers'}
>
{t('modelManager.addDiffuserModel')}
</IAIButton>

View File

@ -66,13 +66,13 @@ export default function AddDiffusersModel() {
};
return (
<Flex overflow="scroll" maxHeight={window.innerHeight - 270}>
<Flex overflow="scroll" maxHeight={window.innerHeight - 270} width="100%">
<Formik
initialValues={addModelFormValues}
onSubmit={addModelFormSubmitHandler}
>
{({ handleSubmit, errors, touched }) => (
<IAIForm onSubmit={handleSubmit}>
<IAIForm onSubmit={handleSubmit} w="full">
<VStack rowGap={2}>
<IAIFormItemWrapper>
{/* Name */}
@ -90,7 +90,6 @@ export default function AddDiffusersModel() {
name="name"
type="text"
validate={baseValidation}
width="2xl"
isRequired
/>
{!!errors.name && touched.name ? (
@ -119,7 +118,6 @@ export default function AddDiffusersModel() {
id="description"
name="description"
type="text"
width="2xl"
isRequired
/>
{!!errors.description && touched.description ? (
@ -153,13 +151,7 @@ export default function AddDiffusersModel() {
{t('modelManager.modelLocation')}
</FormLabel>
<VStack alignItems="start">
<Field
as={IAIInput}
id="path"
name="path"
type="text"
width="2xl"
/>
<Field as={IAIInput} id="path" name="path" type="text" />
{!!errors.path && touched.path ? (
<FormErrorMessage>{errors.path}</FormErrorMessage>
) : (
@ -181,7 +173,6 @@ export default function AddDiffusersModel() {
id="repo_id"
name="repo_id"
type="text"
width="2xl"
/>
{!!errors.repo_id && touched.repo_id ? (
<FormErrorMessage>{errors.repo_id}</FormErrorMessage>
@ -220,7 +211,6 @@ export default function AddDiffusersModel() {
id="vae.path"
name="vae.path"
type="text"
width="2xl"
/>
{!!errors.vae?.path && touched.vae?.path ? (
<FormErrorMessage>{errors.vae?.path}</FormErrorMessage>
@ -245,7 +235,6 @@ export default function AddDiffusersModel() {
id="vae.repo_id"
name="vae.repo_id"
type="text"
width="2xl"
/>
{!!errors.vae?.repo_id && touched.vae?.repo_id ? (
<FormErrorMessage>{errors.vae?.repo_id}</FormErrorMessage>

View File

@ -1,42 +1,85 @@
import { Flex, Radio, RadioGroup, Text, Tooltip } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
Flex,
Radio,
RadioGroup,
Text,
Tooltip,
useColorMode,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAISelect from 'common/components/IAISelect';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import IAISimpleCheckbox from 'common/components/IAISimpleCheckbox';
import IAISlider from 'common/components/IAISlider';
import { addToast } from 'features/system/store/systemSlice';
import { pickBy } from 'lodash-es';
import { useState } from 'react';
import { useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import {
useGetMainModelsQuery,
useMergeMainModelsMutation,
} from 'services/api/endpoints/models';
import { BaseModelType, MergeModelConfig } from 'services/api/types';
import { mode } from 'theme/util/mode';
const baseModelTypeSelectData = [
{ label: 'Stable Diffusion 1', value: 'sd-1' },
{ label: 'Stable Diffusion 2', value: 'sd-2' },
];
type MergeInterpolationMethods =
| 'weighted_sum'
| 'sigmoid'
| 'inv_sigmoid'
| 'add_difference';
export default function MergeModelsPanel() {
const { t } = useTranslation();
const { colorMode } = useColorMode();
const dispatch = useAppDispatch();
const { data } = useGetMainModelsQuery();
const diffusersModels = pickBy(
const [mergeModels, { isLoading }] = useMergeMainModelsMutation();
const [baseModel, setBaseModel] = useState<BaseModelType>('sd-1');
const sd1DiffusersModels = pickBy(
data?.entities,
(value, _) => value?.model_format === 'diffusers'
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-1'
);
const [modelOne, setModelOne] = useState<string>(
Object.keys(diffusersModels)[0]
const sd2DiffusersModels = pickBy(
data?.entities,
(value, _) =>
value?.model_format === 'diffusers' && value?.base_model === 'sd-2'
);
const [modelTwo, setModelTwo] = useState<string>(
Object.keys(diffusersModels)[1]
const modelsMap = useMemo(() => {
return {
'sd-1': sd1DiffusersModels,
'sd-2': sd2DiffusersModels,
};
}, [sd1DiffusersModels, sd2DiffusersModels]);
const [modelOne, setModelOne] = useState<string | null>(
Object.keys(modelsMap[baseModel])[0]
);
const [modelThree, setModelThree] = useState<string>('none');
const [modelTwo, setModelTwo] = useState<string | null>(
Object.keys(modelsMap[baseModel])[1]
);
const [modelThree, setModelThree] = useState<string | null>(null);
const [mergedModelName, setMergedModelName] = useState<string>('');
const [modelMergeAlpha, setModelMergeAlpha] = useState<number>(0.5);
const [modelMergeInterp, setModelMergeInterp] = useState<
'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
>('weighted_sum');
const [modelMergeInterp, setModelMergeInterp] =
useState<MergeInterpolationMethods>('weighted_sum');
const [modelMergeSaveLocType, setModelMergeSaveLocType] = useState<
'root' | 'custom'
@ -47,41 +90,73 @@ export default function MergeModelsPanel() {
const [modelMergeForce, setModelMergeForce] = useState<boolean>(false);
const modelOneList = Object.keys(diffusersModels).filter(
(model) => model !== modelTwo && model !== modelThree
const modelOneList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelTwo && model !== modelThree);
const modelTwoList = Object.keys(
modelsMap[baseModel as keyof typeof modelsMap]
).filter((model) => model !== modelOne && model !== modelThree);
const modelThreeList = Object.keys(modelsMap[baseModel]).filter(
(model) => model !== modelOne && model !== modelTwo
);
const modelTwoList = Object.keys(diffusersModels).filter(
(model) => model !== modelOne && model !== modelThree
);
const modelThreeList = [
{ key: t('modelManager.none'), value: 'none' },
...Object.keys(diffusersModels)
.filter((model) => model !== modelOne && model !== modelTwo)
.map((model) => ({ key: model, value: model })),
];
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const handleBaseModelChange = (v: string) => {
setBaseModel(v as BaseModelType);
setModelOne(null);
setModelTwo(null);
};
const mergeModelsHandler = () => {
let modelsToMerge: string[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== 'none');
const models_names: string[] = [];
const mergeModelsInfo: InvokeAI.InvokeModelMergingProps = {
models_to_merge: modelsToMerge,
let modelsToMerge: (string | null)[] = [modelOne, modelTwo, modelThree];
modelsToMerge = modelsToMerge.filter((model) => model !== null);
modelsToMerge.forEach((model) => {
if (model) {
models_names.push(model?.split('/')[2]);
}
});
const mergeModelsInfo: MergeModelConfig = {
model_names: models_names,
merged_model_name:
mergedModelName !== '' ? mergedModelName : modelsToMerge.join('-'),
mergedModelName !== '' ? mergedModelName : models_names.join('-'),
alpha: modelMergeAlpha,
interp: modelMergeInterp,
model_merge_save_path:
modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
// model_merge_save_path:
// modelMergeSaveLocType === 'root' ? null : modelMergeCustomSaveLoc,
force: modelMergeForce,
};
dispatch(mergeDiffusersModels(mergeModelsInfo));
mergeModels({
base_model: baseModel,
body: mergeModelsInfo,
})
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMerged'),
status: 'success',
})
)
);
})
.catch((error) => {
if (error) {
dispatch(
addToast(
makeToast({
title: t('modelManager.modelsMergeFailed'),
status: 'error',
})
)
);
}
});
};
return (
@ -90,7 +165,6 @@ export default function MergeModelsPanel() {
sx={{
flexDirection: 'column',
rowGap: 1,
bg: 'base.900',
}}
>
<Text>{t('modelManager.modelMergeHeaderHelp1')}</Text>
@ -98,26 +172,43 @@ export default function MergeModelsPanel() {
{t('modelManager.modelMergeHeaderHelp2')}
</Text>
</Flex>
<Flex columnGap={4}>
<IAISelect
<IAIMantineSelect
label="Model Type"
w="100%"
data={baseModelTypeSelectData}
value={baseModel}
onChange={handleBaseModelChange}
/>
<IAIMantineSearchableSelect
label={t('modelManager.modelOne')}
validValues={modelOneList}
onChange={(e) => setModelOne(e.target.value)}
w="100%"
value={modelOne}
placeholder={t('modelManager.selectModel')}
data={modelOneList}
onChange={(v) => setModelOne(v)}
/>
<IAISelect
<IAIMantineSearchableSelect
label={t('modelManager.modelTwo')}
validValues={modelTwoList}
onChange={(e) => setModelTwo(e.target.value)}
w="100%"
placeholder={t('modelManager.selectModel')}
value={modelTwo}
data={modelTwoList}
onChange={(v) => setModelTwo(v)}
/>
<IAISelect
<IAIMantineSearchableSelect
label={t('modelManager.modelThree')}
validValues={modelThreeList}
onChange={(e) => {
if (e.target.value !== 'none') {
setModelThree(e.target.value);
data={modelThreeList}
w="100%"
placeholder={t('modelManager.selectModel')}
clearable
onChange={(v) => {
if (!v) {
setModelThree(null);
setModelMergeInterp('add_difference');
} else {
setModelThree('none');
setModelThree(v);
setModelMergeInterp('weighted_sum');
}
}}
@ -136,7 +227,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<IAISlider
@ -161,7 +252,7 @@ export default function MergeModelsPanel() {
padding: 4,
borderRadius: 'base',
gap: 4,
bg: 'base.900',
bg: mode('base.100', 'base.800')(colorMode),
}}
>
<Text fontWeight={500} fontSize="sm" variant="subtext">
@ -169,12 +260,10 @@ export default function MergeModelsPanel() {
</Text>
<RadioGroup
value={modelMergeInterp}
onChange={(
v: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'
) => setModelMergeInterp(v)}
onChange={(v: MergeInterpolationMethods) => setModelMergeInterp(v)}
>
<Flex columnGap={4}>
{modelThree === 'none' ? (
{modelThree === null ? (
<>
<Radio value="weighted_sum">
<Text fontSize="sm">{t('modelManager.weightedSum')}</Text>
@ -199,7 +288,7 @@ export default function MergeModelsPanel() {
</RadioGroup>
</Flex>
<Flex
{/* <Flex
sx={{
flexDirection: 'column',
padding: 4,
@ -235,7 +324,7 @@ export default function MergeModelsPanel() {
onChange={(e) => setModelMergeCustomSaveLoc(e.target.value)}
/>
)}
</Flex>
</Flex> */}
<IAISimpleCheckbox
label={t('modelManager.ignoreMismatch')}
@ -246,10 +335,8 @@ export default function MergeModelsPanel() {
<IAIButton
onClick={mergeModelsHandler}
isLoading={isProcessing}
isDisabled={
modelMergeSaveLocType === 'custom' && modelMergeCustomSaveLoc === ''
}
isLoading={isLoading}
isDisabled={modelOne === null || modelTwo === null}
>
{t('modelManager.merge')}
</IAIButton>

View File

@ -1,44 +1,60 @@
import { Flex } from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { Flex, Text } from '@chakra-ui/react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
import { useState } from 'react';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit';
import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit';
import ModelList from './ModelManagerPanel/ModelList';
export default function ModelManagerPanel() {
const { data: mainModels } = useGetMainModelsQuery();
const [selectedModelId, setSelectedModelId] = useState<string>();
const { model } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
model: selectedModelId ? data?.entities[selectedModelId] : undefined,
}),
});
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const renderModelEditTabs = () => {
if (!openModel || !mainModels) return;
if (mainModels['entities'][openModel]['model_format'] === 'diffusers') {
return (
<DiffusersModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
} else {
return (
<CheckpointModelEdit
modelToEdit={openModel}
retrievedModel={mainModels['entities'][openModel]}
key={openModel}
/>
);
}
};
return (
<Flex width="100%" columnGap={8}>
<ModelList />
{renderModelEditTabs()}
<Flex sx={{ gap: 8, w: 'full', h: 'full' }}>
<ModelList
selectedModelId={selectedModelId}
setSelectedModelId={setSelectedModelId}
/>
<ModelEdit model={model} />
</Flex>
);
}
type ModelEditProps = {
model: MainModelConfigEntity | undefined;
};
const ModelEdit = (props: ModelEditProps) => {
const { model } = props;
if (model?.model_format === 'checkpoint') {
return <CheckpointModelEdit key={model.id} model={model} />;
}
if (model?.model_format === 'diffusers') {
return <DiffusersModelEdit key={model.id} model={model} />;
}
return (
<Flex
sx={{
w: 'full',
h: 'full',
justifyContent: 'center',
alignItems: 'center',
maxH: 96,
userSelect: 'none',
}}
>
<Text variant="subtext">No Model Selected</Text>
</Flex>
);
};

View File

@ -1,17 +1,20 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useForm } from '@mantine/form';
import { useTranslation } from 'react-i18next';
import type { RootState } from 'app/store/store';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
CheckpointModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
import ModelConvert from './ModelConvert';
const baseModelSelectData = [
@ -25,55 +28,92 @@ const variantSelectData = [
{ value: 'depth', label: 'Depth' },
];
export type CheckpointModel =
| S<'StableDiffusion1ModelCheckpointConfig'>
| S<'StableDiffusion2ModelCheckpointConfig'>;
type CheckpointModelEditProps = {
modelToEdit: string;
retrievedModel: CheckpointModel;
model: CheckpointModelConfigEntity;
};
export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const isBusy = useAppSelector(selectIsBusy);
const { modelToEdit, retrievedModel } = props;
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const checkpointEditForm = useForm({
const checkpointEditForm = useForm<CheckpointModelConfig>({
initialValues: {
name: retrievedModel.model_name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'checkpoint',
vae: retrievedModel.vae,
config: retrievedModel.config,
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
config: model.config ? model.config : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
const editModelFormSubmitHandler = useCallback(
(values: CheckpointModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
checkpointEditForm.setValues(payload as CheckpointModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
checkpointEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
checkpointEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex justifyContent="space-between" alignItems="center">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<ModelConvert model={retrievedModel} />
<ModelConvert model={model} />
</Flex>
<Divider />
@ -88,11 +128,7 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...checkpointEditForm.getInputProps('name')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.description')}
{...checkpointEditForm.getInputProps('description')}
/>
@ -106,36 +142,28 @@ export default function CheckpointModelEdit(props: CheckpointModelEditProps) {
data={variantSelectData}
{...checkpointEditForm.getInputProps('variant')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...checkpointEditForm.getInputProps('path')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.vaeLocation')}
{...checkpointEditForm.getInputProps('vae')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.config')}
{...checkpointEditForm.getInputProps('config')}
/>
<IAIButton disabled={isProcessing} type="submit">
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={500}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,25 +1,23 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Divider, Flex, Text } from '@chakra-ui/react';
// import { addNewModel } from 'app/socketio/actions';
import { useTranslation } from 'react-i18next';
import { useForm } from '@mantine/form';
import type { RootState } from 'app/store/store';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import IAIMantineTextInput from 'common/components/IAIMantineInput';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { S } from 'services/api/types';
type DiffusersModel =
| S<'StableDiffusion1ModelDiffusersConfig'>
| S<'StableDiffusion2ModelDiffusersConfig'>;
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { addToast } from 'features/system/store/systemSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
DiffusersModelConfigEntity,
useUpdateMainModelsMutation,
} from 'services/api/endpoints/models';
import { DiffusersModelConfig } from 'services/api/types';
type DiffusersModelEditProps = {
modelToEdit: string;
retrievedModel: DiffusersModel;
model: DiffusersModelConfigEntity;
};
const baseModelSelectData = [
@ -34,39 +32,82 @@ const variantSelectData = [
];
export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
const isProcessing = useAppSelector(
(state: RootState) => state.system.isProcessing
);
const { retrievedModel, modelToEdit } = props;
const isBusy = useAppSelector(selectIsBusy);
const { model } = props;
const [updateMainModel, { isLoading }] = useUpdateMainModelsMutation();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const diffusersEditForm = useForm({
const diffusersEditForm = useForm<DiffusersModelConfig>({
initialValues: {
name: retrievedModel.model_name,
base_model: retrievedModel.base_model,
type: 'main',
path: retrievedModel.path,
description: retrievedModel.description,
model_name: model.model_name ? model.model_name : '',
base_model: model.base_model,
model_type: 'main',
path: model.path ? model.path : '',
description: model.description ? model.description : '',
model_format: 'diffusers',
vae: retrievedModel.vae,
variant: retrievedModel.variant,
vae: model.vae ? model.vae : '',
variant: model.variant,
},
validate: {
path: (value) =>
value.trim().length === 0 ? 'Must provide a path' : null,
},
});
const editModelFormSubmitHandler = (values) => {
console.log(values);
};
const editModelFormSubmitHandler = useCallback(
(values: DiffusersModelConfig) => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: values,
};
updateMainModel(responseBody)
.unwrap()
.then((payload) => {
diffusersEditForm.setValues(payload as DiffusersModelConfig);
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((error) => {
diffusersEditForm.reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
},
[
diffusersEditForm,
dispatch,
model.base_model,
model.model_name,
t,
updateMainModel,
]
);
return modelToEdit ? (
return (
<Flex flexDirection="column" rowGap={4} width="100%">
<Flex flexDirection="column">
<Text fontSize="lg" fontWeight="bold">
{retrievedModel.model_name}
{model.model_name}
</Text>
<Text fontSize="sm" color="base.400">
{MODEL_TYPE_MAP[retrievedModel.base_model]} Model
{MODEL_TYPE_MAP[model.base_model]} Model
</Text>
</Flex>
<Divider />
@ -77,11 +118,7 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
)}
>
<Flex flexDirection="column" overflowY="scroll" gap={4}>
<IAIInput
label={t('modelManager.name')}
{...diffusersEditForm.getInputProps('name')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.description')}
{...diffusersEditForm.getInputProps('description')}
/>
@ -95,31 +132,23 @@ export default function DiffusersModelEdit(props: DiffusersModelEditProps) {
data={variantSelectData}
{...diffusersEditForm.getInputProps('variant')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.modelLocation')}
{...diffusersEditForm.getInputProps('path')}
/>
<IAIInput
<IAIMantineTextInput
label={t('modelManager.vaeLocation')}
{...diffusersEditForm.getInputProps('vae')}
/>
<IAIButton disabled={isProcessing} type="submit">
<IAIButton
type="submit"
isDisabled={isBusy || isLoading}
isLoading={isLoading}
>
{t('modelManager.updateModel')}
</IAIButton>
</Flex>
</form>
</Flex>
) : (
<Flex
sx={{
width: '100%',
justifyContent: 'center',
alignItems: 'center',
borderRadius: 'base',
bg: 'base.900',
}}
>
<Text fontWeight={'500'}>Pick A Model To Edit</Text>
</Flex>
);
}

View File

@ -1,23 +1,18 @@
import {
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
import { makeToast } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { CheckpointModel } from './CheckpointModelEdit';
import { useConvertMainModelsMutation } from 'services/api/endpoints/models';
import { CheckpointModelConfig } from 'services/api/types';
interface ModelConvertProps {
model: CheckpointModel;
model: CheckpointModelConfig;
}
export default function ModelConvert(props: ModelConvertProps) {
@ -26,6 +21,8 @@ export default function ModelConvert(props: ModelConvertProps) {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
@ -38,20 +35,39 @@ export default function ModelConvert(props: ModelConvertProps) {
};
const modelConvertHandler = () => {
const modelToConvert = {
model_name: model,
save_location: saveLocation,
custom_location:
saveLocation === 'custom' && customSaveLocation !== ''
? customSaveLocation
: null,
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
};
dispatch(convertToDiffusers(modelToConvert));
convertModel(responseBody)
.unwrap()
.then((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConverted')}: ${model.model_name}`,
status: 'success',
})
)
);
})
.catch((_) => {
dispatch(
addToast(
makeToast({
title: `${t('modelManager.modelConversionFailed')}: ${
model.model_name
}`,
status: 'error',
})
)
);
});
};
return (
<IAIAlertDialog
title={`${t('modelManager.convert')} ${model.name}`}
title={`${t('modelManager.convert')} ${model.model_name}`}
acceptCallback={modelConvertHandler}
cancelCallback={modelConvertCancelHandler}
acceptButtonText={`${t('modelManager.convert')}`}
@ -60,6 +76,7 @@ export default function ModelConvert(props: ModelConvertProps) {
size={'sm'}
aria-label={t('modelManager.convertToDiffusers')}
className=" modal-close-btn"
isLoading={isLoading}
>
🧨 {t('modelManager.convertToDiffusers')}
</IAIButton>
@ -77,7 +94,7 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
<Flex flexDir="column" gap={4}>
{/* <Flex flexDir="column" gap={4}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
@ -103,9 +120,9 @@ export default function ModelConvert(props: ModelConvertProps) {
</Radio>
</Flex>
</RadioGroup>
</Flex>
</Flex> */}
{saveLocation === 'custom' && (
{/* {saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')}
@ -119,8 +136,7 @@ export default function ModelConvert(props: ModelConvertProps) {
width="full"
/>
</Flex>
)}
</Flex>
)} */}
</IAIAlertDialog>
);
}

View File

@ -1,185 +1,46 @@
import { Box, Flex, Spinner, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import { EntityState } from '@reduxjs/toolkit';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { forEach } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useGetMainModelsQuery,
} from 'services/api/endpoints/models';
import ModelListItem from './ModelListItem';
import { useTranslation } from 'react-i18next';
type ModelListProps = {
selectedModelId: string | undefined;
setSelectedModelId: (name: string | undefined) => void;
};
import type { ChangeEvent, ReactNode } from 'react';
import React, { useMemo, useState, useTransition } from 'react';
import { useGetMainModelsQuery } from 'services/api/endpoints/models';
function ModelFilterButton({
label,
isActive,
onClick,
}: {
label: string;
isActive: boolean;
onClick: () => void;
}) {
return (
<IAIButton
onClick={onClick}
isActive={isActive}
sx={{
_active: {
bg: 'accent.750',
},
}}
size="sm"
>
{label}
</IAIButton>
);
}
const ModelList = () => {
const { data: mainModels } = useGetMainModelsQuery();
const [renderModelList, setRenderModelList] = React.useState<boolean>(false);
React.useEffect(() => {
const timer = setTimeout(() => {
setRenderModelList(true);
}, 200);
return () => clearTimeout(timer);
}, []);
const [searchText, setSearchText] = useState<string>('');
const [isSelectedFilter, setIsSelectedFilter] = useState<
'all' | 'ckpt' | 'diffusers'
>('all');
const [_, startTransition] = useTransition();
type ModelFormat = 'all' | 'checkpoint' | 'diffusers';
const ModelList = (props: ModelListProps) => {
const { selectedModelId, setSelectedModelId } = props;
const { t } = useTranslation();
const [nameFilter, setNameFilter] = useState<string>('');
const [modelFormatFilter, setModelFormatFilter] =
useState<ModelFormat>('all');
const handleSearchFilter = (e: ChangeEvent<HTMLInputElement>) => {
startTransition(() => {
setSearchText(e.target.value);
});
};
const { filteredDiffusersModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredDiffusersModels: modelsFilter(data, 'diffusers', nameFilter),
}),
});
const renderModelListItems = useMemo(() => {
const ckptModelListItemsToRender: ReactNode[] = [];
const diffusersModelListItemsToRender: ReactNode[] = [];
const filteredModelListItemsToRender: ReactNode[] = [];
const localFilteredModelListItemsToRender: ReactNode[] = [];
const { filteredCheckpointModels } = useGetMainModelsQuery(undefined, {
selectFromResult: ({ data }) => ({
filteredCheckpointModels: modelsFilter(data, 'checkpoint', nameFilter),
}),
});
if (!mainModels) return;
const modelList = mainModels.entities;
Object.keys(modelList).forEach((model, i) => {
if (
modelList[model]?.model_name
.toLowerCase()
.includes(searchText.toLowerCase())
) {
filteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
if (modelList[model]?.model_format === isSelectedFilter) {
localFilteredModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
}
if (modelList[model]?.model_format !== 'diffusers') {
ckptModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
} else {
diffusersModelListItemsToRender.push(
<ModelListItem
key={i}
modelKey={model}
name={modelList[model]?.model_name}
description={modelList[model].description}
/>
);
}
});
return searchText !== '' ? (
isSelectedFilter === 'all' ? (
<Box marginTop={4}>{filteredModelListItemsToRender}</Box>
) : (
<Box marginTop={4}>{localFilteredModelListItemsToRender}</Box>
)
) : (
<Flex flexDirection="column" rowGap={6}>
{isSelectedFilter === 'all' && (
<>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
mb: 4,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.diffusersModels')}
</Text>
{diffusersModelListItemsToRender}
</Box>
<Box>
<Text
sx={{
fontWeight: '500',
py: 2,
px: 4,
my: 4,
mx: 0,
borderRadius: 'base',
width: 'max-content',
fontSize: 'sm',
bg: 'base.750',
}}
>
{t('modelManager.checkpointModels')}
</Text>
{ckptModelListItemsToRender}
</Box>
</>
)}
{isSelectedFilter === 'diffusers' && (
<Flex flexDirection="column" marginTop={4}>
{diffusersModelListItemsToRender}
</Flex>
)}
{isSelectedFilter === 'ckpt' && (
<Flex flexDirection="column" marginTop={4}>
{ckptModelListItemsToRender}
</Flex>
)}
</Flex>
);
}, [mainModels, searchText, t, isSelectedFilter]);
const handleSearchFilter = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setNameFilter(e.target.value);
}, []);
return (
<Flex flexDirection="column" rowGap={4} width="50%" minWidth="50%">
@ -187,7 +48,6 @@ const ModelList = () => {
onChange={handleSearchFilter}
label={t('modelManager.search')}
/>
<Flex
flexDirection="column"
gap={4}
@ -195,39 +55,89 @@ const ModelList = () => {
overflow="scroll"
paddingInlineEnd={4}
>
<Flex columnGap={2}>
<ModelFilterButton
label={t('modelManager.allModels')}
onClick={() => setIsSelectedFilter('all')}
isActive={isSelectedFilter === 'all'}
/>
<ModelFilterButton
label={t('modelManager.diffusersModels')}
onClick={() => setIsSelectedFilter('diffusers')}
isActive={isSelectedFilter === 'diffusers'}
/>
<ModelFilterButton
label={t('modelManager.checkpointModels')}
onClick={() => setIsSelectedFilter('ckpt')}
isActive={isSelectedFilter === 'ckpt'}
/>
</Flex>
{renderModelList ? (
renderModelListItems
) : (
<Flex
width="100%"
minHeight={96}
justifyContent="center"
alignItems="center"
<ButtonGroup isAttached>
<IAIButton
onClick={() => setModelFormatFilter('all')}
isChecked={modelFormatFilter === 'all'}
size="sm"
>
<Spinner />
</Flex>
)}
{t('modelManager.allModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('diffusers')}
isChecked={modelFormatFilter === 'diffusers'}
>
{t('modelManager.diffusersModels')}
</IAIButton>
<IAIButton
size="sm"
onClick={() => setModelFormatFilter('checkpoint')}
isChecked={modelFormatFilter === 'checkpoint'}
>
{t('modelManager.checkpointModels')}
</IAIButton>
</ButtonGroup>
{['all', 'diffusers'].includes(modelFormatFilter) &&
filteredDiffusersModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Diffusers
</Text>
{filteredDiffusersModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
{['all', 'checkpoint'].includes(modelFormatFilter) &&
filteredCheckpointModels.length > 0 && (
<Flex sx={{ gap: 2, flexDir: 'column' }}>
<Text variant="subtext" fontSize="sm">
Checkpoint
</Text>
{filteredCheckpointModels.map((model) => (
<ModelListItem
key={model.id}
model={model}
isSelected={selectedModelId === model.id}
setSelectedModelId={setSelectedModelId}
/>
))}
</Flex>
)}
</Flex>
</Flex>
);
};
export default ModelList;
const modelsFilter = (
data: EntityState<MainModelConfigEntity> | undefined,
model_format: ModelFormat,
nameFilter: string
) => {
const filteredModels: MainModelConfigEntity[] = [];
forEach(data?.entities, (model) => {
if (!model) {
return;
}
const matchesFilter = model.model_name
.toLowerCase()
.includes(nameFilter.toLowerCase());
const matchesFormat = model.model_format === model_format;
if (matchesFilter && matchesFormat) {
filteredModels.push(model);
}
});
return filteredModels;
};

View File

@ -1,98 +1,89 @@
import { DeleteIcon, EditIcon } from '@chakra-ui/icons';
import { Box, Flex, Spacer, Text, Tooltip } from '@chakra-ui/react';
// import { deleteModel, requestModelChange } from 'app/socketio/actions';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { DeleteIcon } from '@chakra-ui/icons';
import { Flex, Text, Tooltip } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIIconButton from 'common/components/IAIIconButton';
import { setOpenModel } from 'features/system/store/systemSlice';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import {
MainModelConfigEntity,
useDeleteMainModelsMutation,
} from 'services/api/endpoints/models';
type ModelListItemProps = {
modelKey: string;
name: string;
description: string | undefined;
model: MainModelConfigEntity;
isSelected: boolean;
setSelectedModelId: (v: string | undefined) => void;
};
export default function ModelListItem(props: ModelListItemProps) {
const { isProcessing, isConnected } = useAppSelector(
(state: RootState) => state.system
);
const openModel = useAppSelector(
(state: RootState) => state.system.openModel
);
const isBusy = useAppSelector(selectIsBusy);
const { t } = useTranslation();
const [deleteMainModel] = useDeleteMainModelsMutation();
const dispatch = useAppDispatch();
const { model, isSelected, setSelectedModelId } = props;
const { modelKey, name, description } = props;
const handleSelectModel = useCallback(() => {
setSelectedModelId(model.id);
}, [model.id, setSelectedModelId]);
const openModelHandler = () => {
dispatch(setOpenModel(modelKey));
};
const handleModelDelete = () => {
dispatch(deleteModel(modelKey));
dispatch(setOpenModel(null));
};
const handleModelDelete = useCallback(() => {
deleteMainModel(model);
setSelectedModelId(undefined);
}, [deleteMainModel, model, setSelectedModelId]);
return (
<Flex
alignItems="center"
p={2}
borderRadius="base"
sx={
modelKey === openModel
? {
bg: 'accent.750',
_hover: {
bg: 'accent.750',
},
}
: {
_hover: {
bg: 'base.750',
},
}
}
>
<Box onClick={openModelHandler} cursor="pointer">
<Tooltip label={description} hasArrow placement="bottom">
<Text fontWeight="600">{name}</Text>
<Flex sx={{ gap: 2, alignItems: 'center', w: 'full' }}>
<Flex
as={IAIButton}
isChecked={isSelected}
sx={{
justifyContent: 'start',
p: 2,
borderRadius: 'base',
w: 'full',
alignItems: 'center',
bg: isSelected ? 'accent.400' : 'base.100',
color: isSelected ? 'base.50' : 'base.800',
_hover: {
bg: isSelected ? 'accent.500' : 'base.200',
color: isSelected ? 'base.50' : 'base.800',
},
_dark: {
color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.600' : 'base.850',
_hover: {
color: isSelected ? 'base.50' : 'base.100',
bg: isSelected ? 'accent.550' : 'base.800',
},
},
}}
onClick={handleSelectModel}
>
<Tooltip label={model.description} hasArrow placement="bottom">
<Text sx={{ fontWeight: 500 }}>{model.model_name}</Text>
</Tooltip>
</Box>
<Spacer onClick={openModelHandler} cursor="pointer" />
<Flex gap={2} alignItems="center">
<IAIIconButton
icon={<EditIcon />}
size="sm"
onClick={openModelHandler}
aria-label={t('accessibility.modifyConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
/>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
size="sm"
aria-label={t('modelManager.deleteConfig')}
isDisabled={status === 'active' || isProcessing || !isConnected}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
<IAIAlertDialog
title={t('modelManager.deleteModel')}
acceptCallback={handleModelDelete}
acceptButtonText={t('modelManager.delete')}
triggerComponent={
<IAIIconButton
icon={<DeleteIcon />}
aria-label={t('modelManager.deleteConfig')}
isDisabled={isBusy}
colorScheme="error"
/>
}
>
<Flex rowGap={4} flexDirection="column">
<p style={{ fontWeight: 'bold' }}>{t('modelManager.deleteMsg1')}</p>
<p>{t('modelManager.deleteMsg2')}</p>
</Flex>
</IAIAlertDialog>
</Flex>
);
}

View File

@ -9,17 +9,15 @@ import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useLayoutEffect } from 'react';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import { ImageDTO } from 'services/api/types';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import {
CanvasInitialImageDropData,
isValidDrop,
useDroppable,
} from 'app/components/ImageDnd/typesafeDnd';
import IAIDropOverlay from 'common/components/IAIDropOverlay';
import { memo, useLayoutEffect } from 'react';
import UnifiedCanvasToolSettingsBeta from './UnifiedCanvasBeta/UnifiedCanvasToolSettingsBeta';
import UnifiedCanvasToolbarBeta from './UnifiedCanvasBeta/UnifiedCanvasToolbarBeta';
const selector = createSelector(
[canvasSelector, uiSelector],

View File

@ -0,0 +1,140 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineMultiSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
searchInput: {
':placeholder': {
color: mode(base300, base700)(colorMode),
},
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 24,
padding: 20,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -0,0 +1,134 @@
import { useColorMode, useToken } from '@chakra-ui/react';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { useCallback } from 'react';
import { mode } from 'theme/util/mode';
export const useMantineSelectStyles = () => {
const {
base50,
base100,
base200,
base300,
base400,
base500,
base600,
base700,
base800,
base900,
accent200,
accent300,
accent400,
accent500,
accent600,
} = useChakraThemeTokens();
const { colorMode } = useColorMode();
const [boxShadow] = useToken('shadows', ['dark-lg']);
const styles = useCallback(
() => ({
label: {
color: mode(base700, base300)(colorMode),
},
separatorLabel: {
color: mode(base500, base500)(colorMode),
'::after': { borderTopColor: mode(base300, base700)(colorMode) },
},
input: {
backgroundColor: mode(base50, base900)(colorMode),
borderWidth: '2px',
borderColor: mode(base200, base800)(colorMode),
color: mode(base900, base100)(colorMode),
paddingRight: 24,
fontWeight: 600,
'&:hover': { borderColor: mode(base300, base600)(colorMode) },
'&:focus': {
borderColor: mode(accent300, accent600)(colorMode),
},
'&:is(:focus, :hover)': {
borderColor: mode(base400, base500)(colorMode),
},
'&:focus-within': {
borderColor: mode(accent200, accent600)(colorMode),
},
'&[data-disabled]': {
backgroundColor: mode(base300, base700)(colorMode),
color: mode(base600, base400)(colorMode),
cursor: 'not-allowed',
},
},
value: {
backgroundColor: mode(base100, base900)(colorMode),
color: mode(base900, base100)(colorMode),
button: {
color: mode(base900, base100)(colorMode),
},
'&:hover': {
backgroundColor: mode(base300, base700)(colorMode),
cursor: 'pointer',
},
},
dropdown: {
backgroundColor: mode(base200, base800)(colorMode),
borderColor: mode(base200, base800)(colorMode),
boxShadow,
},
item: {
backgroundColor: mode(base200, base800)(colorMode),
color: mode(base800, base200)(colorMode),
padding: 6,
'&[data-hovered]': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
'&[data-active]': {
backgroundColor: mode(base300, base700)(colorMode),
'&:hover': {
color: mode(base900, base100)(colorMode),
backgroundColor: mode(base300, base700)(colorMode),
},
},
'&[data-selected]': {
backgroundColor: mode(accent400, accent600)(colorMode),
color: mode(base50, base100)(colorMode),
fontWeight: 600,
'&:hover': {
backgroundColor: mode(accent500, accent500)(colorMode),
color: mode('white', base50)(colorMode),
},
},
'&[data-disabled]': {
color: mode(base500, base600)(colorMode),
cursor: 'not-allowed',
},
},
rightSection: {
width: 32,
button: {
color: mode(base900, base100)(colorMode),
},
},
}),
[
accent200,
accent300,
accent400,
accent500,
accent600,
base100,
base200,
base300,
base400,
base50,
base500,
base600,
base700,
base800,
base900,
boxShadow,
colorMode,
]
);
return styles;
};

View File

@ -1,23 +1,31 @@
import { MantineThemeOverride } from '@mantine/core';
import { useMemo } from 'react';
export const mantineTheme: MantineThemeOverride = {
colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
export const useMantineTheme = () => {
const mantineTheme: MantineThemeOverride = useMemo(
() => ({
colorScheme: 'dark',
fontFamily: `'Inter Variable', sans-serif`,
components: {
ScrollArea: {
defaultProps: {
scrollbarSize: 10,
},
styles: {
scrollbar: {
'&:hover': {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
thumb: {
backgroundColor: 'var(--invokeai-colors-baseAlpha-300)',
},
},
},
},
}),
[]
);
return mantineTheme;
};

View File

@ -4,7 +4,7 @@ import { paths } from '../schema';
type ListBoardImagesArg =
paths['/api/v1/board_images/{board_id}']['get']['parameters']['path'] &
paths['/api/v1/board_images/{board_id}']['get']['parameters']['query'];
paths['/api/v1/board_images/{board_id}']['get']['parameters']['query'];
type AddImageToBoardArg =
paths['/api/v1/board_images/']['post']['requestBody']['content']['application/json'];
@ -25,11 +25,12 @@ export const boardImagesApi = api.injectEndpoints({
query: ({ board_id, offset, limit }) => ({
url: `board_images/${board_id}`,
method: 'GET',
}),
providesTags: (result, error, arg) => {
// any list of boardimages
const tags: ApiFullTagDescription[] = [{ id: 'BoardImage', type: `${arg.board_id}_${LIST_TAG}` }];
const tags: ApiFullTagDescription[] = [
{ type: 'BoardImage', id: `${arg.board_id}_${LIST_TAG}` },
];
if (result) {
// and individual tags for each boardimage
@ -57,7 +58,7 @@ export const boardImagesApi = api.injectEndpoints({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id }
{ type: 'Board', id: arg.board_id },
],
}),
@ -69,7 +70,7 @@ export const boardImagesApi = api.injectEndpoints({
}),
invalidatesTags: (result, error, arg) => [
{ type: 'BoardImage' },
{ type: 'Board', id: arg.board_id }
{ type: 'Board', id: arg.board_id },
],
}),
}),

View File

@ -20,7 +20,7 @@ export const boardsApi = api.injectEndpoints({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@ -43,7 +43,7 @@ export const boardsApi = api.injectEndpoints({
}),
providesTags: (result, error, arg) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ id: 'Board', type: LIST_TAG }];
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
if (result) {
// and individual tags for each board
@ -69,7 +69,7 @@ export const boardsApi = api.injectEndpoints({
method: 'POST',
params: { board_name },
}),
invalidatesTags: [{ id: 'Board', type: LIST_TAG }],
invalidatesTags: [{ type: 'Board', id: LIST_TAG }],
}),
updateBoard: build.mutation<BoardDTO, UpdateBoardArg>({
@ -87,8 +87,15 @@ export const boardsApi = api.injectEndpoints({
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }],
}),
deleteBoardAndImages: build.mutation<void, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE', params: { include_images: true } }),
invalidatesTags: (result, error, arg) => [{ type: 'Board', id: arg }, { type: 'Image', id: LIST_TAG }],
query: (board_id) => ({
url: `boards/${board_id}`,
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, arg) => [
{ type: 'Board', id: arg },
{ type: 'Image', id: LIST_TAG },
],
}),
}),
});
@ -99,5 +106,5 @@ export const {
useCreateBoardMutation,
useUpdateBoardMutation,
useDeleteBoardMutation,
useDeleteBoardAndImagesMutation
useDeleteBoardAndImagesMutation,
} = boardsApi;

View File

@ -2,16 +2,27 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit';
import { cloneDeep } from 'lodash-es';
import {
AnyModelConfig,
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
DiffusersModelConfig,
LoRAModelConfig,
MainModelConfig,
MergeModelConfig,
TextualInversionModelConfig,
VaeModelConfig,
} from 'services/api/types';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { paths } from '../schema';
export type MainModelConfigEntity = MainModelConfig & { id: string };
export type DiffusersModelConfigEntity = DiffusersModelConfig & { id: string };
export type CheckpointModelConfigEntity = CheckpointModelConfig & {
id: string;
};
export type MainModelConfigEntity =
| DiffusersModelConfigEntity
| CheckpointModelConfigEntity;
export type LoRAModelConfigEntity = LoRAModelConfig & { id: string };
@ -32,6 +43,38 @@ type AnyModelConfigEntity =
| TextualInversionModelConfigEntity
| VaeModelConfigEntity;
type UpdateMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: MainModelConfig;
};
type UpdateMainModelResponse =
paths['/api/v1/models/{base_model}/{model_type}/{model_name}']['patch']['responses']['200']['content']['application/json'];
type DeleteMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
};
type ConvertMainModelResponse =
paths['/api/v1/models/convert/{base_model}/{model_type}/{model_name}']['put']['responses']['200']['content']['application/json'];
type MergeMainModelArg = {
base_model: BaseModelType;
body: MergeModelConfig;
};
type MergeMainModelResponse =
paths['/api/v1/models/merge/{base_model}']['put']['responses']['200']['content']['application/json'];
const mainModelsAdapter = createEntityAdapter<MainModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
@ -76,7 +119,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'main' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'MainModel', type: LIST_TAG },
{ type: 'MainModel', id: LIST_TAG },
];
if (result) {
@ -104,11 +147,58 @@ export const modelsApi = api.injectEndpoints({
);
},
}),
updateMainModels: build.mutation<
UpdateMainModelResponse,
UpdateMainModelArg
>({
query: ({ base_model, model_name, body }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'PATCH',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
deleteMainModels: build.mutation<
DeleteMainModelResponse,
DeleteMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/${base_model}/main/${model_name}`,
method: 'DELETE',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
convertMainModels: build.mutation<
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
mergeMainModels: build.mutation<MergeMainModelResponse, MergeMainModelArg>({
query: ({ base_model, body }) => {
return {
url: `models/merge/${base_model}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'LoRAModel', type: LIST_TAG },
{ type: 'LoRAModel', id: LIST_TAG },
];
if (result) {
@ -143,7 +233,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'ControlNetModel', type: LIST_TAG },
{ type: 'ControlNetModel', id: LIST_TAG },
];
if (result) {
@ -175,7 +265,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'VaeModel', type: LIST_TAG },
{ type: 'VaeModel', id: LIST_TAG },
];
if (result) {
@ -210,7 +300,7 @@ export const modelsApi = api.injectEndpoints({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
const tags: ApiFullTagDescription[] = [
{ id: 'TextualInversionModel', type: LIST_TAG },
{ type: 'TextualInversionModel', id: LIST_TAG },
];
if (result) {
@ -247,4 +337,8 @@ export const {
useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useUpdateMainModelsMutation,
useDeleteMainModelsMutation,
useConvertMainModelsMutation,
useMergeMainModelsMutation,
} = modelsApi;

View File

@ -3290,7 +3290,7 @@ export type components = {
/** ModelsList */
ModelsList: {
/** Models */
models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"])[];
models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[];
};
/**
* MultiplyInvocation
@ -4605,18 +4605,18 @@ export type components = {
*/
image?: components["schemas"]["ImageField"];
};
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@ -4997,7 +4997,7 @@ export type operations = {
/** @description The model imported successfully */
201: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description The model could not be found */
@ -5065,14 +5065,14 @@ export type operations = {
};
requestBody: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
responses: {
/** @description The model was updated successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Bad request */
@ -5106,7 +5106,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Bad request */
@ -5141,7 +5141,7 @@ export type operations = {
/** @description Model converted successfully */
200: {
content: {
"application/json": components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"];
"application/json": components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"];
};
};
/** @description Incompatible models */

View File

@ -42,17 +42,20 @@ export type ControlNetModelConfig =
components['schemas']['ControlNetModelConfig'];
export type TextualInversionModelConfig =
components['schemas']['TextualInversionModelConfig'];
export type MainModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
export type DiffusersModelConfig =
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig'];
export type CheckpointModelConfig =
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig'];
export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig;
export type AnyModelConfig =
| LoRAModelConfig
| VaeModelConfig
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
// Graphs
export type Graph = components['schemas']['Graph'];

View File

@ -19,16 +19,8 @@ const invokeAI = defineStyle((props) => {
bg: mode('base.200', 'base.600')(props),
color: mode('base.850', 'base.100')(props),
borderRadius: 'base',
textShadow: mode(
'0 0 0.3rem var(--invokeai-colors-base-50)',
'0 0 0.3rem var(--invokeai-colors-base-900)'
)(props),
svg: {
fill: mode('base.850', 'base.100')(props),
filter: mode(
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-100))',
'drop-shadow(0px 0px 0.3rem var(--invokeai-colors-base-800))'
)(props),
},
_hover: {
bg: mode('base.300', 'base.500')(props),
@ -57,16 +49,8 @@ const invokeAI = defineStyle((props) => {
bg: mode(`${c}.400`, `${c}.600`)(props),
color: mode(`base.50`, `base.100`)(props),
borderRadius: 'base',
textShadow: mode(
`0 0 0.3rem var(--invokeai-colors-${c}-600)`,
`0 0 0.3rem var(--invokeai-colors-${c}-800)`
)(props),
svg: {
fill: mode(`base.50`, `base.100`)(props),
filter: mode(
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-600))`,
`drop-shadow(0px 0px 0.3rem var(--invokeai-colors-${c}-800))`
)(props),
},
_disabled,
_hover: {

View File

@ -2,7 +2,7 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react';
import { mode } from '@chakra-ui/theme-tools';
const subtext = defineStyle((props) => ({
color: mode('colors.base.500', 'colors.base.400')(props),
color: mode('base.500', 'base.400')(props),
}));
export const textTheme = defineStyleConfig({