diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index b764b3b336..385ddc5df8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -63,6 +63,9 @@ from .compel import ConditioningField from .controlnet_image_processors import ControlField from .model import ModelInfo, UNetField, VaeField +if choose_torch_device() == torch.device("mps"): + from torch import mps + DEFAULT_PRECISION = choose_precision(choose_torch_device()) @@ -541,6 +544,8 @@ class DenoiseLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.to("cpu") torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, result_latents) @@ -612,6 +617,8 @@ class LatentsToImageInvocation(BaseInvocation): # clear memory as vae decode can request a lot torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() with torch.inference_mode(): # copied from diffusers pipeline @@ -624,6 +631,8 @@ class LatentsToImageInvocation(BaseInvocation): image = VaeImageProcessor.numpy_to_pil(np_image)[0] torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() image_dto = context.services.images.create( image=image, @@ -683,6 +692,8 @@ class ResizeLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() + if device == torch.device("mps"): + mps.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" # context.services.latents.set(name, resized_latents) @@ -719,6 +730,8 @@ class ScaleLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = resized_latents.to("cpu") torch.cuda.empty_cache() + if device == torch.device("mps"): + mps.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" # context.services.latents.set(name, resized_latents) @@ -875,6 +888,8 @@ class BlendLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") torch.cuda.empty_cache() + if device == torch.device("mps"): + mps.empty_cache() name = f"{context.graph_execution_state_id}__{self.id}" # context.services.latents.set(name, resized_latents) diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 6d0f36ad8c..8c015441b7 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -29,8 +29,12 @@ import torch import invokeai.backend.util.logging as logger +from ..util.devices import choose_torch_device from .models import BaseModelType, ModelBase, ModelType, SubModelType +if choose_torch_device() == torch.device("mps"): + from torch import mps + # Maximum size of the cache, in gigs # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously DEFAULT_MAX_CACHE_SIZE = 6.0 @@ -406,6 +410,8 @@ class ModelCache(object): gc.collect() torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") @@ -426,6 +432,8 @@ class ModelCache(object): gc.collect() torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() def _local_model_hash(self, model_path: Union[str, Path]) -> str: sha = hashlib.sha256() diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index a70ed03fda..8c033440e3 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -12,29 +12,26 @@ import { languageSelector } from 'features/system/store/systemSelectors'; import InvokeTabs from 'features/ui/components/InvokeTabs'; import i18n from 'i18n'; import { size } from 'lodash-es'; -import { ReactNode, memo, useCallback, useEffect } from 'react'; +import { memo, useCallback, useEffect } from 'react'; import { ErrorBoundary } from 'react-error-boundary'; import { usePreselectedImage } from '../../features/parameters/hooks/usePreselectedImage'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; +import { useStore } from '@nanostores/react'; +import { $headerComponent } from 'app/store/nanostores/headerComponent'; const DEFAULT_CONFIG = {}; interface Props { config?: PartialAppConfig; - headerComponent?: ReactNode; selectedImage?: { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; } -const App = ({ - config = DEFAULT_CONFIG, - headerComponent, - selectedImage, -}: Props) => { +const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { const language = useAppSelector(languageSelector); const logger = useLogger('system'); @@ -65,6 +62,8 @@ const App = ({ handlePreselectedImage(selectedImage); }, [handlePreselectedImage, selectedImage]); + const headerComponent = useStore($headerComponent); + return ( import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -30,6 +32,7 @@ interface Props extends PropsWithChildren { imageName: string; action: 'sendToImg2Img' | 'sendToCanvas' | 'useAllParameters'; }; + customStarUi?: CustomStarUi; } const InvokeAIUI = ({ @@ -40,6 +43,7 @@ const InvokeAIUI = ({ middleware, projectId, selectedImage, + customStarUi, }: Props) => { useEffect(() => { // configure API client token @@ -80,17 +84,33 @@ const InvokeAIUI = ({ }; }, [apiUrl, token, middleware, projectId]); + useEffect(() => { + if (customStarUi) { + $customStarUI.set(customStarUi); + } + + return () => { + $customStarUI.set(undefined); + }; + }, [customStarUi]); + + useEffect(() => { + if (headerComponent) { + $headerComponent.set(headerComponent); + } + + return () => { + $headerComponent.set(undefined); + }; + }, [headerComponent]); + return ( }> - + diff --git a/invokeai/frontend/web/src/app/store/nanostores/customStarUI.ts b/invokeai/frontend/web/src/app/store/nanostores/customStarUI.ts new file mode 100644 index 0000000000..0459c2f31f --- /dev/null +++ b/invokeai/frontend/web/src/app/store/nanostores/customStarUI.ts @@ -0,0 +1,14 @@ +import { MenuItemProps } from '@chakra-ui/react'; +import { atom } from 'nanostores'; + +export type CustomStarUi = { + on: { + icon: MenuItemProps['icon']; + text: string; + }; + off: { + icon: MenuItemProps['icon']; + text: string; + }; +}; +export const $customStarUI = atom(undefined); diff --git a/invokeai/frontend/web/src/app/store/nanostores/headerComponent.ts b/invokeai/frontend/web/src/app/store/nanostores/headerComponent.ts new file mode 100644 index 0000000000..90a4775ff9 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/nanostores/headerComponent.ts @@ -0,0 +1,4 @@ +import { atom } from 'nanostores'; +import { ReactNode } from 'react'; + +export const $headerComponent = atom(undefined); diff --git a/invokeai/frontend/web/src/app/store/nanostores/index.ts b/invokeai/frontend/web/src/app/store/nanostores/index.ts new file mode 100644 index 0000000000..ae43ed3035 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/nanostores/index.ts @@ -0,0 +1,3 @@ +/** + * For non-serializable data that needs to be available throughout the app, or when redux is not appropriate, use nanostores. + */ diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index ce2a21c6e7..29caa69cbe 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -86,10 +86,7 @@ export const store = configureStore({ .concat(autoBatchEnhancer()); }, middleware: (getDefaultMiddleware) => - getDefaultMiddleware({ - immutableCheck: false, - serializableCheck: false, - }) + getDefaultMiddleware() .concat(api.middleware) .concat(dynamicMiddlewares) .prepend(listenerMiddleware.middleware), diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx index bf2b344b4c..29b45761ee 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx @@ -1,4 +1,6 @@ import { MenuItem } from '@chakra-ui/react'; +import { useStore } from '@nanostores/react'; +import { $customStarUI } from 'app/store/nanostores/customStarUI'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { imagesToChangeSelected, @@ -16,6 +18,7 @@ import { const MultipleSelectionMenuItems = () => { const dispatch = useAppDispatch(); const selection = useAppSelector((state) => state.gallery.selection); + const customStarUi = useStore($customStarUI); const [starImages] = useStarImagesMutation(); const [unstarImages] = useUnstarImagesMutation(); @@ -49,15 +52,18 @@ const MultipleSelectionMenuItems = () => { <> {areAllStarred && ( } + icon={customStarUi ? customStarUi.on.icon : } onClickCapture={handleUnstarSelection} > - Unstar All + {customStarUi ? customStarUi.off.text : `Unstar All`} )} {(areAllUnstarred || (!areAllStarred && !areAllUnstarred)) && ( - } onClickCapture={handleStarSelection}> - Star All + } + onClickCapture={handleStarSelection} + > + {customStarUi ? customStarUi.on.text : `Star All`} )} } onClickCapture={handleChangeBoard}> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index 46c84e85ce..e5b9d94578 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -1,5 +1,7 @@ import { Flex, MenuItem, Spinner } from '@chakra-ui/react'; +import { useStore } from '@nanostores/react'; import { useAppToaster } from 'app/components/Toaster'; +import { $customStarUI } from 'app/store/nanostores/customStarUI'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { @@ -7,6 +9,7 @@ import { isModalOpenChanged, } from 'features/changeBoardModal/store/slice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { initialImageSelected } from 'features/parameters/store/actions'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -32,9 +35,8 @@ import { useUnstarImagesMutation, } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; -import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; -import { workflowLoadRequested } from 'features/nodes/store/actions'; import { configSelector } from '../../../system/store/configSelectors'; +import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions'; type SingleSelectionMenuItemsProps = { imageDTO: ImageDTO; @@ -50,6 +52,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; const { shouldFetchMetadataFromApi } = useAppSelector(configSelector); + const customStarUi = useStore($customStarUI); const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery( { image: imageDTO, shouldFetchMetadataFromApi }, @@ -225,12 +228,18 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { Change Board {imageDTO.starred ? ( - } onClickCapture={handleUnstarImage}> - Unstar Image + } + onClickCapture={handleUnstarImage} + > + {customStarUi ? customStarUi.off.text : `Unstar Image`} ) : ( - } onClickCapture={handleStarImage}> - Star Image + } + onClickCapture={handleStarImage} + > + {customStarUi ? customStarUi.on.text : `Star Image`} )} { const { handleClick, isSelected, selection, selectionCount } = useMultiselect(imageDTO); + const customStarUi = useStore($customStarUI); + const handleDelete = useCallback( (e: MouseEvent) => { e.stopPropagation(); @@ -91,12 +95,22 @@ const GalleryImage = (props: HoverableImageProps) => { const starIcon = useMemo(() => { if (imageDTO?.starred) { - return ; + return customStarUi ? customStarUi.on.icon : ; } if (!imageDTO?.starred && isHovered) { - return ; + return customStarUi ? customStarUi.off.icon : ; } - }, [imageDTO?.starred, isHovered]); + }, [imageDTO?.starred, isHovered, customStarUi]); + + const starTooltip = useMemo(() => { + if (imageDTO?.starred) { + return customStarUi ? customStarUi.off.text : 'Unstar'; + } + if (!imageDTO?.starred) { + return customStarUi ? customStarUi.on.text : 'Star'; + } + return ''; + }, [imageDTO?.starred, customStarUi]); if (!imageDTO) { return ; @@ -131,7 +145,7 @@ const GalleryImage = (props: HoverableImageProps) => { {isHovered && shift && (