diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 00646e70e3..86a922c05a 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -193,7 +193,10 @@ class ModelInstall(object): models_installed.update(self._install_path(path)) # folders style or similar - elif path.is_dir() and any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): + elif path.is_dir() and any([(path/x).exists() for x in \ + {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} + ] + ): models_installed.update(self._install_path(path)) # recursive scan diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index ae576e39d9..d8ecdf81c2 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -3,15 +3,13 @@ from __future__ import annotations import copy from contextlib import contextmanager from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union, List import torch from compel.embeddings_provider import BaseTextualInversionManager from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file -from torch.utils.hooks import RemovableHandle -from transformers import CLIPTextModel - +from transformers import CLIPTextModel, CLIPTokenizer class LoRALayerBase: #rank: Optional[int] @@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase): def get_weight(self): if self.mid is not None: - up = self.up.reshape(up.shape[0], up.shape[1]) - down = self.down.reshape(up.shape[0], up.shape[1]) + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) @@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module): else: # TODO: diff/ia3/... format print( - f">> Encountered unknown lora layer module in {self.name}: {layer_key}" + f">> Encountered unknown lora layer module in {model.name}: {layer_key}" ) return diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index f15dcfac3c..db8a691d29 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -785,7 +785,7 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: scanned_dirs.add(path) continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin'}]): + if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): new_models_found.update(installer.heuristic_import(path)) scanned_dirs.add(path) @@ -794,7 +794,8 @@ class ModelManager(object): if path in known_paths or path.parent in scanned_dirs: continue if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: - new_models_found.update(installer.heuristic_import(path)) + import_result = installer.heuristic_import(path) + new_models_found.update(import_result) self.logger.info(f'Scanned {items_scanned} files and directories, imported {len(new_models_found)} models') installed.update(new_models_found) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 2828cc7ab1..eef3292d6d 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -78,7 +78,6 @@ class ModelProbe(object): format_type = 'diffusers' if model_path.is_dir() else 'checkpoint' else: format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' - model_info = None try: model_type = cls.get_model_type_from_folder(model_path, model) \ @@ -105,7 +104,7 @@ class ModelProbe(object): ) else 512, ) except Exception: - return None + raise return model_info @@ -127,6 +126,8 @@ class ModelProbe(object): return ModelType.Vae elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): return ModelType.Lora + elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): + return ModelType.Lora elif any(key.startswith(v) for v in {"control_model", "input_blocks"}): return ModelType.ControlNet elif key in {"emb_params", "string_to_param"}: @@ -137,7 +138,7 @@ class ModelProbe(object): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): return ModelType.TextualInversion - raise ValueError("Unable to determine model type") + raise ValueError(f"Unable to determine model type for {model_path}") @classmethod def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: @@ -167,7 +168,7 @@ class ModelProbe(object): return type # give up - raise ValueError("Unable to determine model type") + raise ValueError("Unable to determine model type for {folder_path}") @classmethod def _scan_and_load_checkpoint(cls,model_path: Path)->dict: diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 33ef114912..f3ebcb22be 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -678,9 +678,8 @@ def select_and_download_models(opt: Namespace): # this is where the TUI is called else: - # needed because the torch library is loaded, even though we don't use it - # currently commented out because it has started generating errors (?) - # torch.multiprocessing.set_start_method("spawn") + # needed to support the probe() method running under a subprocess + torch.multiprocessing.set_start_method("spawn") # the third argument is needed in the Windows 11 environment in # order to launch and resize a console window running this program diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 97e33f300b..9a0bc865a4 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -1,15 +1,16 @@ import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { memo } from 'react'; +import { RefObject, memo } from 'react'; import { mode } from 'theme/util/mode'; type IAIMultiSelectProps = MultiSelectProps & { tooltip?: string; + inputRef?: RefObject; }; const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { - const { searchable = true, tooltip, ...rest } = props; + const { searchable = true, tooltip, inputRef, ...rest } = props; const { base50, base100, @@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { return ( ({ label: { diff --git a/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx b/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx new file mode 100644 index 0000000000..1dae6f56e6 --- /dev/null +++ b/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx @@ -0,0 +1,33 @@ +import IAIIconButton from 'common/components/IAIIconButton'; +import { memo } from 'react'; +import { BiCode } from 'react-icons/bi'; + +type Props = { + onClick: () => void; +}; + +const AddEmbeddingButton = (props: Props) => { + const { onClick } = props; + return ( + } + sx={{ + p: 2, + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + }} + variant="link" + onClick={onClick} + /> + ); +}; + +export default memo(AddEmbeddingButton); diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx new file mode 100644 index 0000000000..3c2ded0166 --- /dev/null +++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx @@ -0,0 +1,151 @@ +import { + Flex, + Popover, + PopoverBody, + PopoverContent, + PopoverTrigger, + Text, +} from '@chakra-ui/react'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { + PropsWithChildren, + forwardRef, + useCallback, + useMemo, + useRef, +} from 'react'; +import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; + +type EmbeddingSelectItem = { + label: string; + value: string; + description?: string; +}; + +type Props = PropsWithChildren & { + onSelect: (v: string) => void; + isOpen: boolean; + onClose: () => void; +}; + +const ParamEmbeddingPopover = (props: Props) => { + const { onSelect, isOpen, onClose, children } = props; + const { data: embeddingQueryData } = useGetTextualInversionModelsQuery(); + const inputRef = useRef(null); + + const data = useMemo(() => { + if (!embeddingQueryData) { + return []; + } + + const data: EmbeddingSelectItem[] = []; + + forEach(embeddingQueryData.entities, (embedding, _) => { + if (!embedding) return; + + data.push({ + value: embedding.name, + label: embedding.name, + description: embedding.description, + }); + }); + + return data; + }, [embeddingQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + if (v.length === 0) { + return; + } + + onSelect(v[0]); + }, + [onSelect] + ); + + return ( + + {children} + + + {data.length === 0 ? ( + + + No Embeddings Loaded + + + ) : ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + )} + + + + ); +}; + +export default ParamEmbeddingPopover; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; diff --git a/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts b/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx index ea0b3b0fd8..a8d4c84adc 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx @@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) => ({ gallery }) => { const isSelected = gallery.selection.includes(image_name); const selectionCount = gallery.selection.length; + return { isSelected, selectionCount, @@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => { resetIcon={} resetTooltip="Delete image" imageSx={{ w: 'full', h: 'full' }} - withResetIcon + // withResetIcon // removed bc it's too easy to accidentally delete images isDropDisabled={true} isUploadDisabled={true} /> diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx index 23459e9410..4ca9700a8c 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton'; import IAISlider from 'common/components/IAISlider'; import { memo, useCallback } from 'react'; import { FaTrash } from 'react-icons/fa'; -import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; +import { + Lora, + loraRemoved, + loraWeightChanged, + loraWeightReset, +} from '../store/loraSlice'; type Props = { lora: Lora; @@ -22,7 +27,7 @@ const ParamLora = (props: Props) => { ); const handleReset = useCallback(() => { - dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); + dispatch(loraWeightReset(lora.id)); }, [dispatch, lora.id]); const handleRemoveLora = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index 54ac3d615d..9168814f35 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -1,4 +1,4 @@ -import { Text } from '@chakra-ui/react'; +import { Flex, Text } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -61,6 +61,16 @@ const ParamLoraSelect = () => { [dispatch, lorasQueryData?.entities] ); + if (lorasQueryData?.ids.length === 0) { + return ( + + + No LoRAs Loaded + + + ); + } + return ( = { - weight: 1, + weight: 0.75, }; export type LoraState = { @@ -38,9 +38,14 @@ export const loraSlice = createSlice({ const { id, weight } = action.payload; state.loras[id].weight = weight; }, + loraWeightReset: (state, action: PayloadAction) => { + const id = action.payload; + state.loras[id].weight = defaultLoRAConfig.weight; + }, }, }); -export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; +export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } = + loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx index 589b751d6b..3e5320ad47 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx @@ -1,29 +1,107 @@ -import { FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import type { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; import { setNegativePrompt } from 'features/parameters/store/generationSlice'; +import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; +import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; const ParamNegativeConditioning = () => { const negativePrompt = useAppSelector( (state: RootState) => state.generation.negativePrompt ); - + const promptRef = useRef(null); + const { isOpen, onClose, onOpen } = useDisclosure(); const dispatch = useAppDispatch(); const { t } = useTranslation(); + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setNegativePrompt(e.target.value)); + }, + [dispatch] + ); + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === '<') { + onOpen(); + } + }, + [onOpen] + ); + + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = negativePrompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += negativePrompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setNegativePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, negativePrompt] + ); + return ( - dispatch(setNegativePrompt(e.target.value))} - placeholder={t('parameters.negativePromptPlaceholder')} - fontSize="sm" - minH={16} - /> + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx index f42942a84b..cbff29e89c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx @@ -1,4 +1,4 @@ -import { Box, FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; @@ -11,12 +11,15 @@ import { } from 'features/parameters/store/generationSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { isEqual } from 'lodash-es'; -import { useHotkeys } from 'react-hotkeys-hook'; -import { useTranslation } from 'react-i18next'; import { userInvoked } from 'app/store/actions'; import IAITextarea from 'common/components/IAITextarea'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; +import { isEqual } from 'lodash-es'; +import { flushSync } from 'react-dom'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; const promptInputSelector = createSelector( [(state: RootState) => state.generation, activeTabNameSelector], @@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => { const dispatch = useAppDispatch(); const { prompt, activeTabName } = useAppSelector(promptInputSelector); const isReady = useIsReadyToInvoke(); - const promptRef = useRef(null); - + const { isOpen, onClose, onOpen } = useDisclosure(); const { t } = useTranslation(); - - const handleChangePrompt = (e: ChangeEvent) => { - dispatch(setPositivePrompt(e.target.value)); - }; + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setPositivePrompt(e.target.value)); + }, + [dispatch] + ); useHotkeys( 'alt+a', @@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => { [] ); + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = prompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += prompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setPositivePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger + promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, prompt] + ); + const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && e.shiftKey === false && isReady) { @@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => { dispatch(clampSymmetrySteps()); dispatch(userInvoked(activeTabName)); } + if (e.key === '<') { + onOpen(); + } }, - [dispatch, activeTabName, isReady] + [isReady, dispatch, activeTabName, onOpen] ); + // const handleSelect = (e: MouseEvent) => { + // const target = e.target as HTMLTextAreaElement; + // setCaret({ start: target.selectionStart, end: target.selectionEnd }); + // }; + return ( - + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx b/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx index a742e2a587..30cc1d2158 100644 --- a/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx +++ b/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx @@ -1,4 +1,3 @@ -import { Tooltip } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton, { IAIIconButtonProps, @@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => { }; return ( - - : } - variant="ghost" - size="sm" - sx={{ - color: 'base.700', - _hover: { - color: 'base.550', - }, - _active: { - color: 'base.500', - }, - ...sx, - }} - /> - + : } + variant="ghost" + size="sm" + sx={{ + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + ...sx, + }} + /> ); }; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 38af668cac..861bf49405 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,10 +1,10 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; export const initialUIState: UIState = { activeTab: 0, @@ -19,6 +19,7 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, + shouldShowEmbeddingPicker: false, favoriteSchedulers: [], }; @@ -96,6 +97,9 @@ export const uiSlice = createSlice({ ) => { state.favoriteSchedulers = action.payload; }, + toggleEmbeddingPicker: (state) => { + state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker; + }, }, extraReducers(builder) { builder.addCase(initialImageChanged, (state) => { @@ -122,6 +126,7 @@ export const { toggleGalleryPanel, setShouldShowProgressInViewer, favoriteSchedulersChanged, + toggleEmbeddingPicker, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index d55a1d8fcf..ad0250e56d 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -27,5 +27,6 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; + shouldShowEmbeddingPicker: boolean; favoriteSchedulers: SchedulerParam[]; } diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index 85641b88a0..665761a626 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -1,18 +1,18 @@ import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; -import { io, Socket } from 'socket.io-client'; +import { Socket, io } from 'socket.io-client'; +import { AppThunkDispatch, RootState } from 'app/store/store'; +import { getTimestamp } from 'common/util/getTimestamp'; +import { sessionCreated } from 'services/api/thunks/session'; import { ClientToServerEvents, ServerToClientEvents, } from 'services/events/types'; import { socketSubscribed, socketUnsubscribed } from './actions'; -import { AppThunkDispatch, RootState } from 'app/store/store'; -import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCreated } from 'services/api/thunks/session'; // import { OpenAPI } from 'services/api/types'; -import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 'app/logging/useLogger'; import { $authToken, $baseUrl } from 'services/api/client'; +import { setEventListeners } from 'services/events/util/setEventListeners'; const socketioLog = log.child({ namespace: 'socketio' }); @@ -88,7 +88,7 @@ export const socketMiddleware = () => { socketSubscribed({ sessionId: sessionId, timestamp: getTimestamp(), - boardId: getState().boards.selectedBoardId, + boardId: getState().gallery.selectedBoardId, }) ); }