diff --git a/docs/nodes/communityNodes.md b/docs/nodes/communityNodes.md index 2b30b9f0af..151b9ea262 100644 --- a/docs/nodes/communityNodes.md +++ b/docs/nodes/communityNodes.md @@ -121,18 +121,6 @@ To be imported, an .obj must use triangulated meshes, so make sure to enable tha **Example Usage:** ![depth from obj usage graph](https://raw.githubusercontent.com/dwringer/depth-from-obj-node/main/depth_from_obj_usage.jpg) --------------------------------- -### Enhance Image (simple adjustments) - -**Description:** Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module. - -Color inversion is toggled with a simple switch, while each of the four enhancer modes are activated by entering a value other than 1 in each corresponding input field. Values less than 1 will reduce the corresponding property, while values greater than 1 will enhance it. - -**Node Link:** https://github.com/dwringer/image-enhance-node - -**Example Usage:** -![enhance image usage graph](https://raw.githubusercontent.com/dwringer/image-enhance-node/main/image_enhance_usage.jpg) - -------------------------------- ### Generative Grammar-Based Prompt Nodes @@ -153,16 +141,26 @@ This includes 3 Nodes: **Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling. -This includes 4 Nodes: -- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke. +This includes 14 Nodes: +- *Adjust Image Hue Plus* - Rotate the hue of an image in one of several different color spaces. +- *Blend Latents/Noise (Masked)* - Use a mask to blend part of one latents tensor [including Noise outputs] into another. Can be used to "renoise" sections during a multi-stage [masked] denoising process. +- *Enhance Image* - Boost or reduce color saturation, contrast, brightness, sharpness, or invert colors of any image at any stage with this simple wrapper for pillow [PIL]'s ImageEnhance module. +- *Equivalent Achromatic Lightness* - Calculates image lightness accounting for Helmholtz-Kohlrausch effect based on a method described by High, Green, and Nussbaum (2023). +- *Text to Mask (Clipseg)* - Input a prompt and an image to generate a mask representing areas of the image matched by the prompt. +- *Text to Mask Advanced (Clipseg)* - Output up to four prompt masks combined with logical "and", logical "or", or as separate channels of an RGBA image. +- *Image Layer Blend* - Perform a layered blend of two images using alpha compositing. Opacity of top layer is selectable, with optional mask and several different blend modes/color spaces. - *Image Compositor* - Take a subject from an image with a flat backdrop and layer it on another image using a chroma key or flood select background removal. +- *Image Dilate or Erode* - Dilate or expand a mask (or any image!). This is equivalent to an expand/contract operation. +- *Image Value Thresholds* - Clip an image to pure black/white beyond specified thresholds. - *Offset Latents* - Offset a latents tensor in the vertical and/or horizontal dimensions, wrapping it around. - *Offset Image* - Offset an image in the vertical and/or horizontal dimensions, wrapping it around. +- *Shadows/Highlights/Midtones* - Extract three masks (with adjustable hard or soft thresholds) representing shadows, midtones, and highlights regions of an image. +- *Text Mask (simple 2D)* - create and position a white on black (or black on white) line of text using any font locally available to Invoke. **Node Link:** https://github.com/dwringer/composition-nodes -**Example Usage:** -![composition nodes usage graph](https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_nodes_usage.jpg) +**Nodes and Output Examples:** +![composition nodes usage graph](https://raw.githubusercontent.com/dwringer/composition-nodes/main/composition_pack_overview.jpg) -------------------------------- ### Size Stepper Nodes diff --git a/installer/lib/installer.py b/installer/lib/installer.py index aaf5779801..70ed4d4331 100644 --- a/installer/lib/installer.py +++ b/installer/lib/installer.py @@ -332,6 +332,7 @@ class InvokeAiInstance: Configure the InvokeAI runtime directory """ + auto_install = False # set sys.argv to a consistent state new_argv = [sys.argv[0]] for i in range(1, len(sys.argv)): @@ -340,13 +341,17 @@ class InvokeAiInstance: new_argv.append(el) new_argv.append(sys.argv[i + 1]) elif el in ["-y", "--yes", "--yes-to-all"]: - new_argv.append(el) + auto_install = True sys.argv = new_argv + import messages import requests # to catch download exceptions - from messages import introduction - introduction() + auto_install = auto_install or messages.user_wants_auto_configuration() + if auto_install: + sys.argv.append("--yes") + else: + messages.introduction() from invokeai.frontend.install.invokeai_configure import invokeai_configure diff --git a/installer/lib/messages.py b/installer/lib/messages.py index c5a39dc91c..e4c03bbfd2 100644 --- a/installer/lib/messages.py +++ b/installer/lib/messages.py @@ -7,7 +7,7 @@ import os import platform from pathlib import Path -from prompt_toolkit import prompt +from prompt_toolkit import HTML, prompt from prompt_toolkit.completion import PathCompleter from prompt_toolkit.validation import Validator from rich import box, print @@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool: if dest.exists(): print(f":exclamation: Directory {dest} already exists :exclamation:") dest_confirmed = Confirm.ask( - ":stop_sign: Are you sure you want to (re)install in this location?", + ":stop_sign: (re)install in this location?", default=False, ) else: print(f"InvokeAI will be installed in {dest}") - dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False) + dest_confirmed = Confirm.ask("Use this location?", default=True) console.line() return dest_confirmed +def user_wants_auto_configuration() -> bool: + """Prompt the user to choose between manual and auto configuration.""" + console.rule("InvokeAI Configuration Section") + console.print( + Panel( + Group( + "\n".join( + [ + "Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:", + "", + " * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.", + " * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.", + "", + "Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.", + ] + ), + ), + box=box.MINIMAL, + padding=(1, 1), + ) + ) + choice = ( + prompt( + HTML("Choose <a>utomatic or <m>anual configuration [a/m] (a): "), + validator=Validator.from_callable( + lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'" + ), + ) + or "a" + ) + return choice.lower().startswith("a") + + def dest_path(dest=None) -> Path: """ Prompt the user for the destination path and create the path diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index cb5d2d79f0..ebc40f5ce5 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -156,8 +156,6 @@ async def import_model( prediction_types = {x.value: x for x in SchedulerPredictionType} logger = ApiDependencies.invoker.services.logger - print(f"DEBUG: prediction_type = {prediction_type}") - try: installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index af7a343274..3285de3d5a 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -91,6 +91,9 @@ class FieldDescriptions: board = "The board to save the image to" image = "The image to process" tile_size = "Tile size" + inclusive_low = "The inclusive low value" + exclusive_high = "The exclusive high value" + decimal_places = "The number of decimal places to round to" class Input(str, Enum): diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 3cdd43fb59..b52cbb28bf 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -65,13 +65,27 @@ class DivideInvocation(BaseInvocation): class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" - low: int = InputField(default=0, description="The inclusive low value") - high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") + low: int = InputField(default=0, description=FieldDescriptions.inclusive_low) + high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high) def invoke(self, context: InvocationContext) -> IntegerOutput: return IntegerOutput(value=np.random.randint(self.low, self.high)) +@invocation("rand_float", title="Random Float", tags=["math", "float", "random"], category="math", version="1.0.0") +class RandomFloatInvocation(BaseInvocation): + """Outputs a single random float""" + + low: float = InputField(default=0.0, description=FieldDescriptions.inclusive_low) + high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high) + decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places) + + def invoke(self, context: InvocationContext) -> FloatOutput: + random_float = np.random.uniform(self.low, self.high) + rounded_float = round(random_float, self.decimals) + return FloatOutput(value=rounded_float) + + @invocation( "float_to_int", title="Float To Integer", diff --git a/invokeai/app/services/config/invokeai_config.py b/invokeai/app/services/config/invokeai_config.py index 51ccf45704..8ea703f39a 100644 --- a/invokeai/app/services/config/invokeai_config.py +++ b/invokeai/app/services/config/invokeai_config.py @@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") # CACHE - ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) + ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", ) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", ) diff --git a/invokeai/app/services/invocation_cache/invocation_cache_memory.py b/invokeai/app/services/invocation_cache/invocation_cache_memory.py index be07029f4d..817dbb958e 100644 --- a/invokeai/app/services/invocation_cache/invocation_cache_memory.py +++ b/invokeai/app/services/invocation_cache/invocation_cache_memory.py @@ -1,4 +1,6 @@ -from queue import Queue +from collections import OrderedDict +from dataclasses import dataclass, field +from threading import Lock from typing import Optional, Union from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -7,22 +9,28 @@ from invokeai.app.services.invocation_cache.invocation_cache_common import Invoc from invokeai.app.services.invoker import Invoker +@dataclass(order=True) +class CachedItem: + invocation_output: BaseInvocationOutput = field(compare=False) + invocation_output_json: str = field(compare=False) + + class MemoryInvocationCache(InvocationCacheBase): - _cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]] + _cache: OrderedDict[Union[int, str], CachedItem] _max_cache_size: int _disabled: bool _hits: int _misses: int - _cache_ids: Queue _invoker: Invoker + _lock: Lock def __init__(self, max_cache_size: int = 0) -> None: - self._cache = dict() + self._cache = OrderedDict() self._max_cache_size = max_cache_size self._disabled = False self._hits = 0 self._misses = 0 - self._cache_ids = Queue() + self._lock = Lock() def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -32,80 +40,87 @@ class MemoryInvocationCache(InvocationCacheBase): self._invoker.services.latents.on_deleted(self._delete_by_match) def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]: - if self._max_cache_size == 0 or self._disabled: - return - - item = self._cache.get(key, None) - if item is not None: - self._hits += 1 - return item[0] - self._misses += 1 + with self._lock: + if self._max_cache_size == 0 or self._disabled: + return None + item = self._cache.get(key, None) + if item is not None: + self._hits += 1 + self._cache.move_to_end(key) + return item.invocation_output + self._misses += 1 + return None def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None: - if self._max_cache_size == 0 or self._disabled: - return + with self._lock: + if self._max_cache_size == 0 or self._disabled or key in self._cache: + return + # If the cache is full, we need to remove the least used + number_to_delete = len(self._cache) + 1 - self._max_cache_size + self._delete_oldest_access(number_to_delete) + self._cache[key] = CachedItem(invocation_output, invocation_output.json()) - if key not in self._cache: - self._cache[key] = (invocation_output, invocation_output.json()) - self._cache_ids.put(key) - if self._cache_ids.qsize() > self._max_cache_size: - try: - self._cache.pop(self._cache_ids.get()) - except KeyError: - # this means the cache_ids are somehow out of sync w/ the cache - pass + def _delete_oldest_access(self, number_to_delete: int) -> None: + number_to_delete = min(number_to_delete, len(self._cache)) + for _ in range(number_to_delete): + self._cache.popitem(last=False) - def delete(self, key: Union[int, str]) -> None: + def _delete(self, key: Union[int, str]) -> None: if self._max_cache_size == 0: return - if key in self._cache: del self._cache[key] + def delete(self, key: Union[int, str]) -> None: + with self._lock: + return self._delete(key) + def clear(self, *args, **kwargs) -> None: - if self._max_cache_size == 0: - return + with self._lock: + if self._max_cache_size == 0: + return + self._cache.clear() + self._misses = 0 + self._hits = 0 - self._cache.clear() - self._cache_ids = Queue() - self._misses = 0 - self._hits = 0 - - def create_key(self, invocation: BaseInvocation) -> int: + @staticmethod + def create_key(invocation: BaseInvocation) -> int: return hash(invocation.json(exclude={"id"})) def disable(self) -> None: - if self._max_cache_size == 0: - return - self._disabled = True + with self._lock: + if self._max_cache_size == 0: + return + self._disabled = True def enable(self) -> None: - if self._max_cache_size == 0: - return - self._disabled = False + with self._lock: + if self._max_cache_size == 0: + return + self._disabled = False def get_status(self) -> InvocationCacheStatus: - return InvocationCacheStatus( - hits=self._hits, - misses=self._misses, - enabled=not self._disabled and self._max_cache_size > 0, - size=len(self._cache), - max_size=self._max_cache_size, - ) + with self._lock: + return InvocationCacheStatus( + hits=self._hits, + misses=self._misses, + enabled=not self._disabled and self._max_cache_size > 0, + size=len(self._cache), + max_size=self._max_cache_size, + ) def _delete_by_match(self, to_match: str) -> None: - if self._max_cache_size == 0: - return - - keys_to_delete = set() - for key, value_tuple in self._cache.items(): - if to_match in value_tuple[1]: - keys_to_delete.add(key) - - if not keys_to_delete: - return - - for key in keys_to_delete: - self.delete(key) - - self._invoker.services.logger.debug(f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}") + with self._lock: + if self._max_cache_size == 0: + return + keys_to_delete = set() + for key, cached_item in self._cache.items(): + if to_match in cached_item.invocation_output_json: + keys_to_delete.add(key) + if not keys_to_delete: + return + for key in keys_to_delete: + self._delete(key) + self._invoker.services.logger.debug( + f"Deleted {len(keys_to_delete)} cached invocation outputs for {to_match}" + ) diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index ec2221e12d..5afbdfb5a3 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]: config = InvokeAIAppConfig.get_config() Model_dir = "models" - Default_config_file = config.model_conf_path SD_Configs = config.legacy_conf_path @@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle. ) self.add_widget_intelligent( npyscreen.TitleFixedText, - name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.", + name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).", begin_entry_at=0, editable=False, color="CONTROL", @@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam return editApp.new_opts() +def default_ramcache() -> float: + """Run a heuristic for the default RAM cache based on installed RAM.""" + + # Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB, + # So we adjust everthing down a bit. + return ( + 15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1 + ) # 2.1 is just large enough for sd 1.5 ;-) + + def default_startup_options(init_file: Path) -> Namespace: opts = InvokeAIAppConfig.get_config() + opts.ram = default_ramcache() return opts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index a28ef8d490..fc9dd0cc5f 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -58,6 +58,7 @@ "githubLabel": "Github", "hotkeysLabel": "Hotkeys", "imagePrompt": "Image Prompt", + "imageFailedToLoad": "Unable to Load Image", "img2img": "Image To Image", "langArabic": "العربية", "langBrPortuguese": "Português do Brasil", @@ -79,7 +80,7 @@ "lightMode": "Light Mode", "linear": "Linear", "load": "Load", - "loading": "Loading", + "loading": "Loading $t({{noun}})...", "loadingInvokeAI": "Loading Invoke AI", "learnMore": "Learn More", "modelManager": "Model Manager", @@ -716,6 +717,7 @@ "cannotConnectInputToInput": "Cannot connect input to input", "cannotConnectOutputToOutput": "Cannot connect output to output", "cannotConnectToSelf": "Cannot connect to self", + "cannotDuplicateConnection": "Cannot create duplicate connections", "clipField": "Clip", "clipFieldDescription": "Tokenizer and text_encoder submodels.", "collection": "Collection", @@ -1442,6 +1444,8 @@ "showCanvasDebugInfo": "Show Additional Canvas Info", "showGrid": "Show Grid", "showHide": "Show/Hide", + "showResultsOn": "Show Results (On)", + "showResultsOff": "Show Results (Off)", "showIntermediates": "Show Intermediates", "snapToGrid": "Snap to Grid", "undo": "Undo" diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index 9bcc7c831b..a9d56a7f16 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -5,7 +5,7 @@ import { } from '@chakra-ui/react'; import { ReactNode, memo, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { theme as invokeAITheme } from 'theme/theme'; +import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme'; import '@fontsource-variable/inter'; import { MantineProvider } from '@mantine/core'; @@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { return ( - + {children} diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index ead6e1cd42..677b0fd20c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected'; import { addModelsLoadedListener } from './listeners/modelsLoaded'; import { addDynamicPromptsListener } from './listeners/promptChanged'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; -import { - addSessionCanceledFulfilledListener, - addSessionCanceledPendingListener, - addSessionCanceledRejectedListener, -} from './listeners/sessionCanceled'; -import { - addSessionCreatedFulfilledListener, - addSessionCreatedPendingListener, - addSessionCreatedRejectedListener, -} from './listeners/sessionCreated'; -import { - addSessionInvokedFulfilledListener, - addSessionInvokedPendingListener, - addSessionInvokedRejectedListener, -} from './listeners/sessionInvoked'; import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected'; import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; @@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa import { addTabChangedListener } from './listeners/tabChanged'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; +import { addBatchEnqueuedListener } from './listeners/batchEnqueued'; export const listenerMiddleware = createListenerMiddleware(); @@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener(); addEnqueueRequestedNodes(); addEnqueueRequestedLinear(); addAnyEnqueuedListener(); +addBatchEnqueuedListener(); // Canvas actions addCanvasSavedToGalleryListener(); @@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener(); addInvocationRetrievalErrorEventListener(); addSocketQueueItemStatusChangedEventListener(); -// Session Created -addSessionCreatedPendingListener(); -addSessionCreatedFulfilledListener(); -addSessionCreatedRejectedListener(); - -// Session Invoked -addSessionInvokedPendingListener(); -addSessionInvokedFulfilledListener(); -addSessionInvokedRejectedListener(); - -// Session Canceled -addSessionCanceledPendingListener(); -addSessionCanceledFulfilledListener(); -addSessionCanceledRejectedListener(); - // ControlNet addControlNetImageProcessedListener(); addControlNetAutoProcessListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts new file mode 100644 index 0000000000..fe351f3be6 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/batchEnqueued.ts @@ -0,0 +1,96 @@ +import { createStandaloneToast } from '@chakra-ui/react'; +import { logger } from 'app/logging/logger'; +import { parseify } from 'common/util/serialize'; +import { zPydanticValidationError } from 'features/system/store/zodSchemas'; +import { t } from 'i18next'; +import { get, truncate, upperFirst } from 'lodash-es'; +import { queueApi } from 'services/api/endpoints/queue'; +import { TOAST_OPTIONS, theme } from 'theme/theme'; +import { startAppListening } from '..'; + +const { toast } = createStandaloneToast({ + theme: theme, + defaultOptions: TOAST_OPTIONS.defaultOptions, +}); + +export const addBatchEnqueuedListener = () => { + // success + startAppListening({ + matcher: queueApi.endpoints.enqueueBatch.matchFulfilled, + effect: async (action) => { + const response = action.payload; + const arg = action.meta.arg.originalArgs; + logger('queue').debug( + { enqueueResult: parseify(response) }, + 'Batch enqueued' + ); + + if (!toast.isActive('batch-queued')) { + toast({ + id: 'batch-queued', + title: t('queue.batchQueued'), + description: t('queue.batchQueuedDesc', { + item_count: response.enqueued, + direction: arg.prepend ? t('queue.front') : t('queue.back'), + }), + duration: 1000, + status: 'success', + }); + } + }, + }); + + // error + startAppListening({ + matcher: queueApi.endpoints.enqueueBatch.matchRejected, + effect: async (action) => { + const response = action.payload; + const arg = action.meta.arg.originalArgs; + + if (!response) { + toast({ + title: t('queue.batchFailedToQueue'), + status: 'error', + description: 'Unknown Error', + }); + logger('queue').error( + { batchConfig: parseify(arg), error: parseify(response) }, + t('queue.batchFailedToQueue') + ); + return; + } + + const result = zPydanticValidationError.safeParse(response); + if (result.success) { + result.data.data.detail.map((e) => { + toast({ + id: 'batch-failed-to-queue', + title: truncate(upperFirst(e.msg), { length: 128 }), + status: 'error', + description: truncate( + `Path: + ${e.loc.join('.')}`, + { length: 128 } + ), + }); + }); + } else { + let detail = 'Unknown Error'; + if (response.status === 403 && 'body' in response) { + detail = get(response, 'body.detail', 'Unknown Error'); + } else if (response.status === 403 && 'error' in response) { + detail = get(response, 'error.detail', 'Unknown Error'); + } + toast({ + title: t('queue.batchFailedToQueue'), + status: 'error', + description: detail, + }); + } + logger('queue').error( + { batchConfig: parseify(arg), error: parseify(response) }, + t('queue.batchFailedToQueue') + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts index 1b13181911..a8e1a04fc1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -25,7 +25,7 @@ export const addBoardIdSelectedListener = () => { const state = getState(); const board_id = boardIdSelected.match(action) - ? action.payload + ? action.payload.boardId : state.gallery.selectedBoardId; const galleryView = galleryViewChanged.match(action) @@ -55,7 +55,12 @@ export const addBoardIdSelectedListener = () => { if (boardImagesData) { const firstImage = imagesSelectors.selectAll(boardImagesData)[0]; - dispatch(imageSelected(firstImage ?? null)); + const selectedImage = imagesSelectors.selectById( + boardImagesData, + action.payload.selectedImageName + ); + + dispatch(imageSelected(selectedImage || firstImage || null)); } else { // board has no images - deselect dispatch(imageSelected(null)); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts index 835b8246f1..9389b0f373 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasImageToControlNet.ts @@ -3,9 +3,9 @@ import { canvasImageToControlNet } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { addToast } from 'features/system/store/systemSlice'; +import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '..'; -import { t } from 'i18next'; export const addCanvasImageToControlNetListener = () => { startAppListening({ @@ -16,7 +16,7 @@ export const addCanvasImageToControlNetListener = () => { let blob; try { - blob = await getBaseLayerBlob(state); + blob = await getBaseLayerBlob(state, true); } catch (err) { log.error(String(err)); dispatch( @@ -36,10 +36,10 @@ export const addCanvasImageToControlNetListener = () => { file: new File([blob], 'savedCanvas.png', { type: 'image/png', }), - image_category: 'mask', + image_category: 'control', is_intermediate: false, board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - crop_visible: true, + crop_visible: false, postUploadAction: { type: 'TOAST', toastOptions: { title: t('toast.canvasSentControlnetAssets') }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts index 671c7f63e4..2c5c26e830 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMaskToControlNet.ts @@ -3,9 +3,9 @@ import { canvasMaskToControlNet } from 'features/canvas/store/actions'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { addToast } from 'features/system/store/systemSlice'; +import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '..'; -import { t } from 'i18next'; export const addCanvasMaskToControlNetListener = () => { startAppListening({ @@ -50,7 +50,7 @@ export const addCanvasMaskToControlNetListener = () => { image_category: 'mask', is_intermediate: false, board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId, - crop_visible: true, + crop_visible: false, postUploadAction: { type: 'TOAST', toastOptions: { title: t('toast.maskSentControlnetAssets') }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts index c1511bd0e8..8c283ce64e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts @@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio import { canvasGraphBuilt } from 'features/nodes/store/actions'; import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; import { ImageDTO } from 'services/api/types'; @@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => { const enqueueResult = await req.unwrap(); req.reset(); - log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued'); - const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it // Prep the canvas staging area if it is not yet initialized @@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => { // Associate the session with the canvas session ID dispatch(canvasBatchIdAdded(batchId)); - - dispatch( - addToast({ - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - item_count: enqueueResult.enqueued, - direction: prepend ? t('queue.front') : t('queue.back'), - }), - status: 'success', - }) - ); } catch { - log.error( - { batchConfig: parseify(batchConfig) }, - t('queue.batchFailedToQueue') - ); - dispatch( - addToast({ - title: t('queue.batchFailedToQueue'), - status: 'error', - }) - ); + // no-op } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index e36c6f2ebe..bb89d18b91 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,13 +1,9 @@ -import { logger } from 'app/logging/logger'; import { enqueueRequested } from 'app/store/actions'; -import { parseify } from 'common/util/serialize'; import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph'; import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph'; import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph'; import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; import { startAppListening } from '..'; @@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => { (action.payload.tabName === 'txt2img' || action.payload.tabName === 'img2img'), effect: async (action, { getState, dispatch }) => { - const log = logger('queue'); const state = getState(); const model = state.generation.model; const { prepend } = action.payload; @@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => { const batchConfig = prepareLinearUIBatch(state, graph, prepend); - try { - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - fixedCacheKey: 'enqueueBatch', - }) - ); - const enqueueResult = await req.unwrap(); - req.reset(); - - log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued'); - dispatch( - addToast({ - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - item_count: enqueueResult.enqueued, - direction: prepend ? t('queue.front') : t('queue.back'), - }), - status: 'success', - }) - ); - } catch { - log.error( - { batchConfig: parseify(batchConfig) }, - t('queue.batchFailedToQueue') - ); - dispatch( - addToast({ - title: t('queue.batchFailedToQueue'), - status: 'error', - }) - ); - } + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { + fixedCacheKey: 'enqueueBatch', + }) + ); + req.reset(); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index 31281678d4..b87e443a4e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -1,9 +1,5 @@ -import { logger } from 'app/logging/logger'; import { enqueueRequested } from 'app/store/actions'; -import { parseify } from 'common/util/serialize'; import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; import { BatchConfig } from 'services/api/types'; import { startAppListening } from '..'; @@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => { predicate: (action): action is ReturnType => enqueueRequested.match(action) && action.payload.tabName === 'nodes', effect: async (action, { getState, dispatch }) => { - const log = logger('queue'); const state = getState(); - const { prepend } = action.payload; const graph = buildNodesGraph(state.nodes); const batchConfig: BatchConfig = { batch: { @@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => { prepend: action.payload.prepend, }; - try { - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - fixedCacheKey: 'enqueueBatch', - }) - ); - const enqueueResult = await req.unwrap(); - req.reset(); - - log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued'); - dispatch( - addToast({ - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - item_count: enqueueResult.enqueued, - direction: prepend ? t('queue.front') : t('queue.back'), - }), - status: 'success', - }) - ); - } catch { - log.error( - { batchConfig: parseify(batchConfig) }, - 'Failed to enqueue batch' - ); - dispatch( - addToast({ - title: t('queue.batchFailedToQueue'), - status: 'error', - }) - ); - } + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { + fixedCacheKey: 'enqueueBatch', + }) + ); + req.reset(); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts deleted file mode 100644 index 2592437348..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCanceled.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { serializeError } from 'serialize-error'; -import { sessionCanceled } from 'services/api/thunks/session'; -import { startAppListening } from '..'; - -export const addSessionCanceledPendingListener = () => { - startAppListening({ - actionCreator: sessionCanceled.pending, - effect: () => { - // - }, - }); -}; - -export const addSessionCanceledFulfilledListener = () => { - startAppListening({ - actionCreator: sessionCanceled.fulfilled, - effect: (action) => { - const log = logger('session'); - const { session_id } = action.meta.arg; - log.debug({ session_id }, `Session canceled (${session_id})`); - }, - }); -}; - -export const addSessionCanceledRejectedListener = () => { - startAppListening({ - actionCreator: sessionCanceled.rejected, - effect: (action) => { - const log = logger('session'); - const { session_id } = action.meta.arg; - if (action.payload) { - const { error } = action.payload; - log.error( - { - session_id, - error: serializeError(error), - }, - `Problem canceling session` - ); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts deleted file mode 100644 index e89acb7542..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionCreated.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { parseify } from 'common/util/serialize'; -import { serializeError } from 'serialize-error'; -import { sessionCreated } from 'services/api/thunks/session'; -import { startAppListening } from '..'; - -export const addSessionCreatedPendingListener = () => { - startAppListening({ - actionCreator: sessionCreated.pending, - effect: () => { - // - }, - }); -}; - -export const addSessionCreatedFulfilledListener = () => { - startAppListening({ - actionCreator: sessionCreated.fulfilled, - effect: (action) => { - const log = logger('session'); - const session = action.payload; - log.debug( - { session: parseify(session) }, - `Session created (${session.id})` - ); - }, - }); -}; - -export const addSessionCreatedRejectedListener = () => { - startAppListening({ - actionCreator: sessionCreated.rejected, - effect: (action) => { - const log = logger('session'); - if (action.payload) { - const { error, status } = action.payload; - const graph = parseify(action.meta.arg); - log.error( - { graph, status, error: serializeError(error) }, - `Problem creating session` - ); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts deleted file mode 100644 index a62f75d957..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/sessionInvoked.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { serializeError } from 'serialize-error'; -import { sessionInvoked } from 'services/api/thunks/session'; -import { startAppListening } from '..'; - -export const addSessionInvokedPendingListener = () => { - startAppListening({ - actionCreator: sessionInvoked.pending, - effect: () => { - // - }, - }); -}; - -export const addSessionInvokedFulfilledListener = () => { - startAppListening({ - actionCreator: sessionInvoked.fulfilled, - effect: (action) => { - const log = logger('session'); - const { session_id } = action.meta.arg; - log.debug({ session_id }, `Session invoked (${session_id})`); - }, - }); -}; - -export const addSessionInvokedRejectedListener = () => { - startAppListening({ - actionCreator: sessionInvoked.rejected, - effect: (action) => { - const log = logger('session'); - const { session_id } = action.meta.arg; - if (action.payload) { - const { error } = action.payload; - log.error( - { - session_id, - error: serializeError(error), - }, - `Problem invoking session` - ); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 7e918410a7..beaa4835b3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -81,9 +81,32 @@ export const addInvocationCompleteEventListener = () => { // If auto-switch is enabled, select the new image if (shouldAutoSwitch) { - // if auto-add is enabled, switch the board as the image comes in - dispatch(galleryViewChanged('images')); - dispatch(boardIdSelected(imageDTO.board_id ?? 'none')); + // if auto-add is enabled, switch the gallery view and board if needed as the image comes in + if (gallery.galleryView !== 'images') { + dispatch(galleryViewChanged('images')); + } + + if ( + imageDTO.board_id && + imageDTO.board_id !== gallery.selectedBoardId + ) { + dispatch( + boardIdSelected({ + boardId: imageDTO.board_id, + selectedImageName: imageDTO.image_name, + }) + ); + } + + if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') { + dispatch( + boardIdSelected({ + boardId: 'none', + selectedImageName: imageDTO.image_name, + }) + ); + } + dispatch(imageSelected(imageDTO)); } } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts index b0377e950b..4af35dbe9c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts @@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => { queueApi.util.invalidateTags([ 'CurrentSessionQueueItem', 'NextSessionQueueItem', + 'InvocationCacheStatus', { type: 'SessionQueueItem', id: item_id }, { type: 'SessionQueueItemDTO', id: item_id }, { type: 'BatchStatus', id: queue_batch_id }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/util/enqueueBatch.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/util/enqueueBatch.ts deleted file mode 100644 index 1d5a1232c8..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/util/enqueueBatch.ts +++ /dev/null @@ -1,54 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { AppThunkDispatch } from 'app/store/store'; -import { parseify } from 'common/util/serialize'; -import { addToast } from 'features/system/store/systemSlice'; -import { t } from 'i18next'; -import { queueApi } from 'services/api/endpoints/queue'; -import { BatchConfig } from 'services/api/types'; - -export const enqueueBatch = async ( - batchConfig: BatchConfig, - dispatch: AppThunkDispatch -) => { - const log = logger('session'); - const { prepend } = batchConfig; - - try { - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { - fixedCacheKey: 'enqueueBatch', - }) - ); - const enqueueResult = await req.unwrap(); - req.reset(); - - dispatch( - queueApi.endpoints.resumeProcessor.initiate(undefined, { - fixedCacheKey: 'resumeProcessor', - }) - ); - - log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued'); - dispatch( - addToast({ - title: t('queue.batchQueued'), - description: t('queue.batchQueuedDesc', { - item_count: enqueueResult.enqueued, - direction: prepend ? t('queue.front') : t('queue.back'), - }), - status: 'success', - }) - ); - } catch { - log.error( - { batchConfig: parseify(batchConfig) }, - t('queue.batchFailedToQueue') - ); - dispatch( - addToast({ - title: t('queue.batchFailedToQueue'), - status: 'error', - }) - ); - } -}; diff --git a/invokeai/frontend/web/src/common/components/IAIColorPicker.tsx b/invokeai/frontend/web/src/common/components/IAIColorPicker.tsx index f6a05c86b1..5854f7503f 100644 --- a/invokeai/frontend/web/src/common/components/IAIColorPicker.tsx +++ b/invokeai/frontend/web/src/common/components/IAIColorPicker.tsx @@ -1,18 +1,9 @@ -import { chakra, ChakraProps } from '@chakra-ui/react'; +import { Box, ChakraProps } from '@chakra-ui/react'; import { memo } from 'react'; import { RgbaColorPicker } from 'react-colorful'; import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types'; -type IAIColorPickerProps = Omit, 'color'> & - ChakraProps & { - pickerColor: RgbaColor; - styleClass?: string; - }; - -const ChakraRgbaColorPicker = chakra(RgbaColorPicker, { - baseStyle: { paddingInline: 4 }, - shouldForwardProp: (prop) => !['pickerColor'].includes(prop), -}); +type IAIColorPickerProps = ColorPickerBaseProps; const colorPickerStyles: NonNullable = { width: 6, @@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable = { borderColor: 'base.100', }; -const IAIColorPicker = (props: IAIColorPickerProps) => { - const { styleClass = '', ...rest } = props; +const sx = { + '.react-colorful__hue-pointer': colorPickerStyles, + '.react-colorful__saturation-pointer': colorPickerStyles, + '.react-colorful__alpha-pointer': colorPickerStyles, +}; +const IAIColorPicker = (props: IAIColorPickerProps) => { return ( - + + + ); }; diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index ca61ea847f..3c1a05d527 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -81,3 +81,38 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => { ); }; + +type IAINoImageFallbackWithSpinnerProps = FlexProps & { + label?: string; +}; + +export const IAINoContentFallbackWithSpinner = ( + props: IAINoImageFallbackWithSpinnerProps +) => { + const { sx, ...rest } = props; + + return ( + + + {props.label && {props.label}} + + ); +}; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx index e2f20f99a2..360d764a6e 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvas.tsx @@ -139,6 +139,11 @@ const IAICanvas = () => { const { handleDragStart, handleDragMove, handleDragEnd } = useCanvasDragMove(); + const handleContextMenu = useCallback( + (e: KonvaEventObject) => e.evt.preventDefault(), + [] + ); + useEffect(() => { if (!containerRef.current) { return; @@ -205,9 +210,7 @@ const IAICanvas = () => { onDragStart={handleDragStart} onDragMove={handleDragMove} onDragEnd={handleDragEnd} - onContextMenu={(e: KonvaEventObject) => - e.evt.preventDefault() - } + onContextMenu={handleContextMenu} onWheel={handleWheel} draggable={(tool === 'move' || isStaging) && !isModifyingBoundingBox} > @@ -223,7 +226,11 @@ const IAICanvas = () => { > - + diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx index 9f8829c280..d87d912a1e 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImage.tsx @@ -1,26 +1,27 @@ import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { Image, Rect } from 'react-konva'; +import { memo } from 'react'; +import { Image } from 'react-konva'; +import { $authToken } from 'services/api/client'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import useImage from 'use-image'; import { CanvasImage } from '../store/canvasTypes'; -import { $authToken } from 'services/api/client'; -import { memo } from 'react'; +import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback'; type IAICanvasImageProps = { canvasImage: CanvasImage; }; const IAICanvasImage = (props: IAICanvasImageProps) => { - const { width, height, x, y, imageName } = props.canvasImage; + const { x, y, imageName } = props.canvasImage; const { currentData: imageDTO, isError } = useGetImageDTOQuery( imageName ?? skipToken ); - const [image] = useImage( + const [image, status] = useImage( imageDTO?.image_url ?? '', $authToken.get() ? 'use-credentials' : 'anonymous' ); - if (isError) { - return ; + if (isError || status === 'failed') { + return ; } return ; diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasImageErrorFallback.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImageErrorFallback.tsx new file mode 100644 index 0000000000..38322daafa --- /dev/null +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasImageErrorFallback.tsx @@ -0,0 +1,44 @@ +import { useColorModeValue, useToken } from '@chakra-ui/react'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Group, Rect, Text } from 'react-konva'; +import { CanvasImage } from '../store/canvasTypes'; + +type IAICanvasImageErrorFallbackProps = { + canvasImage: CanvasImage; +}; +const IAICanvasImageErrorFallback = ({ + canvasImage, +}: IAICanvasImageErrorFallbackProps) => { + const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] = + useToken('colors', ['base.400', 'base.500', 'base.700', 'base.900']); + const errorColor = useColorModeValue(errorColorLight, errorColorDark); + const fontColor = useColorModeValue(fontColorLight, fontColorDark); + const { t } = useTranslation(); + return ( + + + + + ); +}; + +export default memo(IAICanvasImageErrorFallback); diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx index fa73f020da..4585ab76af 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingArea.tsx @@ -3,10 +3,9 @@ import { useAppSelector } from 'app/store/storeHooks'; import { canvasSelector } from 'features/canvas/store/canvasSelectors'; import { GroupConfig } from 'konva/lib/Group'; import { isEqual } from 'lodash-es'; - +import { memo } from 'react'; import { Group, Rect } from 'react-konva'; import IAICanvasImage from './IAICanvasImage'; -import { memo } from 'react'; const selector = createSelector( [canvasSelector], @@ -15,11 +14,11 @@ const selector = createSelector( layerState, shouldShowStagingImage, shouldShowStagingOutline, - boundingBoxCoordinates: { x, y }, - boundingBoxDimensions: { width, height }, + boundingBoxCoordinates: stageBoundingBoxCoordinates, + boundingBoxDimensions: stageBoundingBoxDimensions, } = canvas; - const { selectedImageIndex, images } = layerState.stagingArea; + const { selectedImageIndex, images, boundingBox } = layerState.stagingArea; return { currentStagingAreaImage: @@ -30,10 +29,10 @@ const selector = createSelector( isOnLastImage: selectedImageIndex === images.length - 1, shouldShowStagingImage, shouldShowStagingOutline, - x, - y, - width, - height, + x: boundingBox?.x ?? stageBoundingBoxCoordinates.x, + y: boundingBox?.y ?? stageBoundingBoxCoordinates.y, + width: boundingBox?.width ?? stageBoundingBoxDimensions.width, + height: boundingBox?.height ?? stageBoundingBoxDimensions.height, }; }, { diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx index 3e617f8767..8bb45840d0 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasStagingAreaToolbar.tsx @@ -14,6 +14,7 @@ import { import { skipToken } from '@reduxjs/toolkit/dist/query'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -23,8 +24,8 @@ import { FaCheck, FaEye, FaEyeSlash, - FaPlus, FaSave, + FaTimes, } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { stagingAreaImageSaved } from '../store/actions'; @@ -41,10 +42,10 @@ const selector = createSelector( } = canvas; return { + currentIndex: selectedImageIndex, + total: images.length, currentStagingAreaImage: images.length > 0 ? images[selectedImageIndex] : undefined, - isOnFirstImage: selectedImageIndex === 0, - isOnLastImage: selectedImageIndex === images.length - 1, shouldShowStagingImage, shouldShowStagingOutline, }; @@ -55,10 +56,10 @@ const selector = createSelector( const IAICanvasStagingAreaToolbar = () => { const dispatch = useAppDispatch(); const { - isOnFirstImage, - isOnLastImage, currentStagingAreaImage, shouldShowStagingImage, + currentIndex, + total, } = useAppSelector(selector); const { t } = useTranslation(); @@ -71,39 +72,6 @@ const IAICanvasStagingAreaToolbar = () => { dispatch(setShouldShowStagingOutline(false)); }, [dispatch]); - useHotkeys( - ['left'], - () => { - handlePrevImage(); - }, - { - enabled: () => true, - preventDefault: true, - } - ); - - useHotkeys( - ['right'], - () => { - handleNextImage(); - }, - { - enabled: () => true, - preventDefault: true, - } - ); - - useHotkeys( - ['enter'], - () => { - handleAccept(); - }, - { - enabled: () => true, - preventDefault: true, - } - ); - const handlePrevImage = useCallback( () => dispatch(prevStagingAreaImage()), [dispatch] @@ -119,10 +87,45 @@ const IAICanvasStagingAreaToolbar = () => { [dispatch] ); + useHotkeys(['left'], handlePrevImage, { + enabled: () => true, + preventDefault: true, + }); + + useHotkeys(['right'], handleNextImage, { + enabled: () => true, + preventDefault: true, + }); + + useHotkeys(['enter'], () => handleAccept, { + enabled: () => true, + preventDefault: true, + }); + const { data: imageDTO } = useGetImageDTOQuery( currentStagingAreaImage?.imageName ?? skipToken ); + const handleToggleShouldShowStagingImage = useCallback(() => { + dispatch(setShouldShowStagingImage(!shouldShowStagingImage)); + }, [dispatch, shouldShowStagingImage]); + + const handleSaveToGallery = useCallback(() => { + if (!imageDTO) { + return; + } + + dispatch( + stagingAreaImageSaved({ + imageDTO, + }) + ); + }, [dispatch, imageDTO]); + + const handleDiscardStagingArea = useCallback(() => { + dispatch(discardStagedImages()); + }, [dispatch]); + if (!currentStagingAreaImage) { return null; } @@ -131,11 +134,12 @@ const IAICanvasStagingAreaToolbar = () => { { icon={} onClick={handlePrevImage} colorScheme="accent" - isDisabled={isOnFirstImage} + isDisabled={!shouldShowStagingImage} /> + {`${currentIndex + 1}/${total}`} } onClick={handleNextImage} colorScheme="accent" - isDisabled={isOnLastImage} + isDisabled={!shouldShowStagingImage} /> + + { colorScheme="accent" /> : } - onClick={() => - dispatch(setShouldShowStagingImage(!shouldShowStagingImage)) - } + onClick={handleToggleShouldShowStagingImage} colorScheme="accent" /> { aria-label={t('unifiedCanvas.saveToGallery')} isDisabled={!imageDTO || !imageDTO.is_intermediate} icon={} - onClick={() => { - if (!imageDTO) { - return; - } - - dispatch( - stagingAreaImageSaved({ - imageDTO, - }) - ); - }} + onClick={handleSaveToGallery} colorScheme="accent" /> } - onClick={() => dispatch(discardStagedImages())} + icon={} + onClick={handleDiscardStagingArea} colorScheme="error" fontSize={20} /> diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx index 0f94b1c57a..8f86605726 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasBoundingBox.tsx @@ -213,45 +213,45 @@ const IAICanvasBoundingBox = (props: IAICanvasBoundingBoxPreviewProps) => { [scaledStep] ); - const handleStartedTransforming = () => { + const handleStartedTransforming = useCallback(() => { dispatch(setIsTransformingBoundingBox(true)); - }; + }, [dispatch]); - const handleEndedTransforming = () => { + const handleEndedTransforming = useCallback(() => { dispatch(setIsTransformingBoundingBox(false)); dispatch(setIsMovingBoundingBox(false)); dispatch(setIsMouseOverBoundingBox(false)); setIsMouseOverBoundingBoxOutline(false); - }; + }, [dispatch]); - const handleStartedMoving = () => { + const handleStartedMoving = useCallback(() => { dispatch(setIsMovingBoundingBox(true)); - }; + }, [dispatch]); - const handleEndedModifying = () => { + const handleEndedModifying = useCallback(() => { dispatch(setIsTransformingBoundingBox(false)); dispatch(setIsMovingBoundingBox(false)); dispatch(setIsMouseOverBoundingBox(false)); setIsMouseOverBoundingBoxOutline(false); - }; + }, [dispatch]); - const handleMouseOver = () => { + const handleMouseOver = useCallback(() => { setIsMouseOverBoundingBoxOutline(true); - }; + }, []); - const handleMouseOut = () => { + const handleMouseOut = useCallback(() => { !isTransformingBoundingBox && !isMovingBoundingBox && setIsMouseOverBoundingBoxOutline(false); - }; + }, [isMovingBoundingBox, isTransformingBoundingBox]); - const handleMouseEnterBoundingBox = () => { + const handleMouseEnterBoundingBox = useCallback(() => { dispatch(setIsMouseOverBoundingBox(true)); - }; + }, [dispatch]); - const handleMouseLeaveBoundingBox = () => { + const handleMouseLeaveBoundingBox = useCallback(() => { dispatch(setIsMouseOverBoundingBox(false)); - }; + }, [dispatch]); return ( diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx index 76211a2e95..43e8febd66 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasMaskOptions.tsx @@ -1,4 +1,4 @@ -import { ButtonGroup, Flex } from '@chakra-ui/react'; +import { Box, ButtonGroup, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIButton from 'common/components/IAIButton'; @@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => { dispatch(setShouldPreserveMaskedArea(e.target.checked)) } /> - dispatch(setMaskColor(newColor))} - /> + + dispatch(setMaskColor(newColor))} + /> + } onClick={handleSaveMask}> Save Mask diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx index 6a7db0e5f2..b5770fdda6 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasToolChooserOptions.tsx @@ -1,4 +1,4 @@ -import { ButtonGroup, Flex } from '@chakra-ui/react'; +import { ButtonGroup, Flex, Box } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => { sliderNumberInputProps={{ max: 500 }} /> - dispatch(setBrushColor(newColor))} - /> + > + dispatch(setBrushColor(newColor))} + /> + diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSelectors.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSelectors.ts index 46bf7db3d0..8f1e246aaa 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSelectors.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSelectors.ts @@ -6,7 +6,7 @@ export const canvasSelector = (state: RootState): CanvasState => state.canvas; export const isStagingSelector = createSelector( [stateSelector], - ({ canvas }) => canvas.layerState.stagingArea.images.length > 0 + ({ canvas }) => canvas.batchIds.length > 0 ); export const initialCanvasImageSelector = ( diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index b726e757f6..df601e9e67 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice'; import { IRect, Vector2d } from 'konva/lib/types'; import { clamp, cloneDeep } from 'lodash-es'; import { RgbaColor } from 'react-colorful'; -import { sessionCanceled } from 'services/api/thunks/session'; import { ImageDTO } from 'services/api/types'; import calculateCoordinates from '../util/calculateCoordinates'; import calculateScale from '../util/calculateScale'; @@ -187,7 +186,7 @@ export const canvasSlice = createSlice({ state.pastLayerStates.push(cloneDeep(state.layerState)); state.layerState = { - ...initialLayerState, + ...cloneDeep(initialLayerState), objects: [ { kind: 'image', @@ -201,6 +200,7 @@ export const canvasSlice = createSlice({ ], }; state.futureLayerStates = []; + state.batchIds = []; const newScale = calculateScale( stageDimensions.width, @@ -350,11 +350,14 @@ export const canvasSlice = createSlice({ state.pastLayerStates.shift(); } - state.layerState.stagingArea = { ...initialLayerState.stagingArea }; + state.layerState.stagingArea = cloneDeep( + cloneDeep(initialLayerState) + ).stagingArea; state.futureLayerStates = []; state.shouldShowStagingOutline = true; - state.shouldShowStagingOutline = true; + state.shouldShowStagingImage = true; + state.batchIds = []; }, addFillRect: (state) => { const { boundingBoxCoordinates, boundingBoxDimensions, brushColor } = @@ -491,8 +494,9 @@ export const canvasSlice = createSlice({ resetCanvas: (state) => { state.pastLayerStates.push(cloneDeep(state.layerState)); - state.layerState = initialLayerState; + state.layerState = cloneDeep(initialLayerState); state.futureLayerStates = []; + state.batchIds = []; }, canvasResized: ( state, @@ -617,25 +621,22 @@ export const canvasSlice = createSlice({ return; } - const currentIndex = state.layerState.stagingArea.selectedImageIndex; - const length = state.layerState.stagingArea.images.length; + const nextIndex = state.layerState.stagingArea.selectedImageIndex + 1; + const lastIndex = state.layerState.stagingArea.images.length - 1; - state.layerState.stagingArea.selectedImageIndex = Math.min( - currentIndex + 1, - length - 1 - ); + state.layerState.stagingArea.selectedImageIndex = + nextIndex > lastIndex ? 0 : nextIndex; }, prevStagingAreaImage: (state) => { if (!state.layerState.stagingArea.images.length) { return; } - const currentIndex = state.layerState.stagingArea.selectedImageIndex; + const prevIndex = state.layerState.stagingArea.selectedImageIndex - 1; + const lastIndex = state.layerState.stagingArea.images.length - 1; - state.layerState.stagingArea.selectedImageIndex = Math.max( - currentIndex - 1, - 0 - ); + state.layerState.stagingArea.selectedImageIndex = + prevIndex < 0 ? lastIndex : prevIndex; }, commitStagingAreaImage: (state) => { if (!state.layerState.stagingArea.images.length) { @@ -657,13 +658,12 @@ export const canvasSlice = createSlice({ ...imageToCommit, }); } - state.layerState.stagingArea = { - ...initialLayerState.stagingArea, - }; + state.layerState.stagingArea = cloneDeep(initialLayerState).stagingArea; state.futureLayerStates = []; state.shouldShowStagingOutline = true; state.shouldShowStagingImage = true; + state.batchIds = []; }, fitBoundingBoxToStage: (state) => { const { @@ -786,11 +786,6 @@ export const canvasSlice = createSlice({ }, }, extraReducers: (builder) => { - builder.addCase(sessionCanceled.pending, (state) => { - if (!state.layerState.stagingArea.images.length) { - state.layerState.stagingArea = initialLayerState.stagingArea; - } - }); builder.addCase(setAspectRatio, (state, action) => { const ratio = action.payload; if (ratio) { diff --git a/invokeai/frontend/web/src/features/canvas/util/getBaseLayerBlob.ts b/invokeai/frontend/web/src/features/canvas/util/getBaseLayerBlob.ts index 3667acc79b..b67789e07e 100644 --- a/invokeai/frontend/web/src/features/canvas/util/getBaseLayerBlob.ts +++ b/invokeai/frontend/web/src/features/canvas/util/getBaseLayerBlob.ts @@ -1,11 +1,14 @@ -import { getCanvasBaseLayer } from './konvaInstanceProvider'; import { RootState } from 'app/store/store'; +import { getCanvasBaseLayer } from './konvaInstanceProvider'; import { konvaNodeToBlob } from './konvaNodeToBlob'; /** * Get the canvas base layer blob, with or without bounding box according to `shouldCropToBoundingBoxOnSave` */ -export const getBaseLayerBlob = async (state: RootState) => { +export const getBaseLayerBlob = async ( + state: RootState, + alwaysUseBoundingBox: boolean = false +) => { const canvasBaseLayer = getCanvasBaseLayer(); if (!canvasBaseLayer) { @@ -24,14 +27,15 @@ export const getBaseLayerBlob = async (state: RootState) => { const absPos = clonedBaseLayer.getAbsolutePosition(); - const boundingBox = shouldCropToBoundingBoxOnSave - ? { - x: boundingBoxCoordinates.x + absPos.x, - y: boundingBoxCoordinates.y + absPos.y, - width: boundingBoxDimensions.width, - height: boundingBoxDimensions.height, - } - : clonedBaseLayer.getClientRect(); + const boundingBox = + shouldCropToBoundingBoxOnSave || alwaysUseBoundingBox + ? { + x: boundingBoxCoordinates.x + absPos.x, + y: boundingBoxCoordinates.y + absPos.y, + width: boundingBoxDimensions.width, + height: boundingBoxDimensions.height, + } + : clonedBaseLayer.getClientRect(); return konvaNodeToBlob(clonedBaseLayer, boundingBox); }; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 70c459f0a4..f0745eae2b 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -6,7 +6,6 @@ import { import { cloneDeep, forEach } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; import { components } from 'services/api/schema'; -import { isAnySessionRejected } from 'services/api/thunks/session'; import { ImageDTO } from 'services/api/types'; import { appSocketInvocationError } from 'services/events/actions'; import { controlNetImageProcessed } from './actions'; @@ -99,6 +98,9 @@ export const controlNetSlice = createSlice({ isControlNetEnabledToggled: (state) => { state.isEnabled = !state.isEnabled; }, + controlNetEnabled: (state) => { + state.isEnabled = true; + }, controlNetAdded: ( state, action: PayloadAction<{ @@ -112,6 +114,12 @@ export const controlNetSlice = createSlice({ controlNetId, }; }, + controlNetRecalled: (state, action: PayloadAction) => { + const controlNet = action.payload; + state.controlNets[controlNet.controlNetId] = { + ...controlNet, + }; + }, controlNetDuplicated: ( state, action: PayloadAction<{ @@ -418,10 +426,6 @@ export const controlNetSlice = createSlice({ state.pendingControlImages = []; }); - builder.addMatcher(isAnySessionRejected, (state) => { - state.pendingControlImages = []; - }); - builder.addMatcher( imagesApi.endpoints.deleteImage.matchFulfilled, (state, action) => { @@ -444,7 +448,9 @@ export const controlNetSlice = createSlice({ export const { isControlNetEnabledToggled, + controlNetEnabled, controlNetAdded, + controlNetRecalled, controlNetDuplicated, controlNetAddedFromImage, controlNetRemoved, diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 1bb6816bd9..104512a9c6 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -93,7 +93,7 @@ const GalleryBoard = ({ const [localBoardName, setLocalBoardName] = useState(board_name); const handleSelectBoard = useCallback(() => { - dispatch(boardIdSelected(board_id)); + dispatch(boardIdSelected({ boardId: board_id })); if (autoAssignBoardOnClick) { dispatch(autoAddBoardIdChanged(board_id)); } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index 55034decf0..6cea7d3eac 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -34,7 +34,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => { const { autoAddBoardId, autoAssignBoardOnClick } = useAppSelector(selector); const boardName = useBoardName('none'); const handleSelectBoard = useCallback(() => { - dispatch(boardIdSelected('none')); + dispatch(boardIdSelected({ boardId: 'none' })); if (autoAssignBoardOnClick) { dispatch(autoAddBoardIdChanged('none')); } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx index b538eee9d1..462aa4b5e6 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/SystemBoardButton.tsx @@ -32,7 +32,7 @@ const SystemBoardButton = ({ board_id }: Props) => { const boardName = useBoardName(board_id); const handleClick = useCallback(() => { - dispatch(boardIdSelected(board_id)); + dispatch(boardIdSelected({ boardId: board_id })); }, [board_id, dispatch]); return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 955e8a5a3a..25d8e1e5ac 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -1,8 +1,15 @@ -import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; +import { + ControlNetMetadataItem, + CoreMetadata, + LoRAMetadataItem, +} from 'features/nodes/types/types'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; -import { memo, useCallback } from 'react'; +import { memo, useMemo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas'; +import { + isValidControlNetModel, + isValidLoRAModel, +} from '../../../parameters/types/parameterSchemas'; import ImageMetadataItem from './ImageMetadataItem'; type Props = { @@ -26,6 +33,7 @@ const ImageMetadataActions = (props: Props) => { recallHeight, recallStrength, recallLoRA, + recallControlNet, } = useRecallParameters(); const handleRecallPositivePrompt = useCallback(() => { @@ -75,6 +83,21 @@ const ImageMetadataActions = (props: Props) => { [recallLoRA] ); + const handleRecallControlNet = useCallback( + (controlnet: ControlNetMetadataItem) => { + recallControlNet(controlnet); + }, + [recallControlNet] + ); + + const validControlNets: ControlNetMetadataItem[] = useMemo(() => { + return metadata?.controlnets + ? metadata.controlnets.filter((controlnet) => + isValidControlNetModel(controlnet.control_model) + ) + : []; + }, [metadata?.controlnets]); + if (!metadata || Object.keys(metadata).length === 0) { return null; } @@ -180,6 +203,14 @@ const ImageMetadataActions = (props: Props) => { ); } })} + {validControlNets.map((controlnet, index) => ( + handleRecallControlNet(controlnet)} + /> + ))} ); }; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index a4e4b02937..c78b22dd78 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -35,8 +35,11 @@ export const gallerySlice = createSlice({ autoAssignBoardOnClickChanged: (state, action: PayloadAction) => { state.autoAssignBoardOnClick = action.payload; }, - boardIdSelected: (state, action: PayloadAction) => { - state.selectedBoardId = action.payload; + boardIdSelected: ( + state, + action: PayloadAction<{ boardId: BoardId; selectedImageName?: string }> + ) => { + state.selectedBoardId = action.payload.boardId; state.galleryView = 'images'; }, autoAddBoardIdChanged: (state, action: PayloadAction) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 57e5825fb9..e2ff7c5bb0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -12,6 +12,7 @@ import { OnConnect, OnConnectEnd, OnConnectStart, + OnEdgeUpdateFunc, OnEdgesChange, OnEdgesDelete, OnInit, @@ -21,6 +22,7 @@ import { OnSelectionChangeFunc, ProOptions, ReactFlow, + ReactFlowProps, XYPosition, } from 'reactflow'; import { useIsValidConnection } from '../../hooks/useIsValidConnection'; @@ -28,6 +30,8 @@ import { connectionEnded, connectionMade, connectionStarted, + edgeAdded, + edgeDeleted, edgesChanged, edgesDeleted, nodesChanged, @@ -167,6 +171,63 @@ export const Flow = () => { } }, []); + // #region Updatable Edges + + /** + * Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/ + * and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/ + * + * - Edges can be dragged from one handle to another. + * - If the user drags the edge away from the node and drops it, delete the edge. + * - Do not delete the edge if the cursor didn't move (resolves annoying behaviour + * where the edge is deleted if you click it accidentally). + */ + + // We have a ref for cursor position, but it is the *projected* cursor position. + // Easiest to just keep track of the last mouse event for this particular feature + const edgeUpdateMouseEvent = useRef(); + + const onEdgeUpdateStart: NonNullable = + useCallback( + (e, edge, _handleType) => { + // update mouse event + edgeUpdateMouseEvent.current = e; + // always delete the edge when starting an updated + dispatch(edgeDeleted(edge.id)); + }, + [dispatch] + ); + + const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( + (_oldEdge, newConnection) => { + // instead of updating the edge (we deleted it earlier), we instead create + // a new one. + dispatch(connectionMade(newConnection)); + }, + [dispatch] + ); + + const onEdgeUpdateEnd: NonNullable = + useCallback( + (e, edge, _handleType) => { + // Handle the case where user begins a drag but didn't move the cursor - + // bc we deleted the edge, we need to add it back + if ( + // ignore touch events + !('touches' in e) && + edgeUpdateMouseEvent.current?.clientX === e.clientX && + edgeUpdateMouseEvent.current?.clientY === e.clientY + ) { + dispatch(edgeAdded(edge)); + } + // reset mouse event + edgeUpdateMouseEvent.current = undefined; + }, + [dispatch] + ); + + // #endregion + useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { e.preventDefault(); dispatch(selectionCopied()); @@ -196,6 +257,9 @@ export const Flow = () => { onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} onEdgesDelete={onEdgesDelete} + onEdgeUpdate={onEdgeUpdate} + onEdgeUpdateStart={onEdgeUpdateStart} + onEdgeUpdateEnd={onEdgeUpdateEnd} onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx index d2e0667ab2..a33a854c3b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx @@ -8,6 +8,7 @@ import InvocationNodeFooter from './InvocationNodeFooter'; import InvocationNodeHeader from './InvocationNodeHeader'; import InputField from './fields/InputField'; import OutputField from './fields/OutputField'; +import { useWithFooter } from 'features/nodes/hooks/useWithFooter'; type Props = { nodeId: string; @@ -20,6 +21,7 @@ type Props = { const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId); const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId); + const withFooter = useWithFooter(nodeId); const outputFieldNames = useOutputFieldNames(nodeId); return ( @@ -41,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { h: 'full', py: 2, gap: 1, - borderBottomRadius: 0, + borderBottomRadius: withFooter ? 0 : 'base', }} > @@ -74,7 +76,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => { ))} - + {withFooter && } )} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx index ba1f7977ab..ec5085221e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx @@ -5,6 +5,7 @@ import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox'; import SaveToGalleryCheckbox from './SaveToGalleryCheckbox'; import UseCacheCheckbox from './UseCacheCheckbox'; import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput'; +import { useFeatureStatus } from '../../../../../system/hooks/useFeatureStatus'; type Props = { nodeId: string; @@ -12,6 +13,7 @@ type Props = { const InvocationNodeFooter = ({ nodeId }: Props) => { const hasImageOutput = useHasImageOutput(nodeId); + const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; return ( { justifyContent: 'space-between', }} > - + {isCacheEnabled && } {hasImageOutput && } {hasImageOutput && } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 0439445c24..a57787556c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -53,13 +53,12 @@ export const useIsValidConnection = () => { } if ( - edges - .filter((edge) => { - return edge.target === target && edge.targetHandle === targetHandle; - }) - .find((edge) => { - edge.source === source && edge.sourceHandle === sourceHandle; - }) + edges.find((edge) => { + edge.target === target && + edge.targetHandle === targetHandle && + edge.source === source && + edge.sourceHandle === sourceHandle; + }) ) { // We already have a connection from this source to this target return false; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts index 57941eaec8..4d2a58cc35 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts @@ -1,31 +1,14 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { some } from 'lodash-es'; +import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useMemo } from 'react'; -import { FOOTER_FIELDS } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { useHasImageOutput } from './useHasImageOutput'; -export const useHasImageOutputs = (nodeId: string) => { - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ nodes }) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return some(node.data.outputs, (output) => - FOOTER_FIELDS.includes(output.type) - ); - }, - defaultSelectorOptions - ), - [nodeId] +export const useWithFooter = (nodeId: string) => { + const hasImageOutput = useHasImageOutput(nodeId); + const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; + + const withFooter = useMemo( + () => hasImageOutput || isCacheEnabled, + [hasImageOutput, isCacheEnabled] ); - - const withFooter = useAppSelector(selector); return withFooter; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 01de3de883..1b3a5ca929 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -15,6 +15,7 @@ import { NodeChange, OnConnectStartParams, SelectionMode, + updateEdge, Viewport, XYPosition, } from 'reactflow'; @@ -182,6 +183,16 @@ const nodesSlice = createSlice({ edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); }, + edgeAdded: (state, action: PayloadAction) => { + state.edges = addEdge(action.payload, state.edges); + }, + edgeUpdated: ( + state, + action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }> + ) => { + const { oldEdge, newConnection } = action.payload; + state.edges = updateEdge(oldEdge, newConnection, state.edges); + }, connectionStarted: (state, action: PayloadAction) => { state.connectionStartParams = action.payload; const { nodeId, handleId, handleType } = action.payload; @@ -366,6 +377,7 @@ const nodesSlice = createSlice({ target: edge.target, type: 'collapsed', data: { count: 1 }, + updatable: false, }); } } @@ -388,6 +400,7 @@ const nodesSlice = createSlice({ target: edge.target, type: 'collapsed', data: { count: 1 }, + updatable: false, }); } } @@ -400,6 +413,9 @@ const nodesSlice = createSlice({ } } }, + edgeDeleted: (state, action: PayloadAction) => { + state.edges = state.edges.filter((e) => e.id !== action.payload); + }, edgesDeleted: (state, action: PayloadAction) => { const edges = action.payload; const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); @@ -890,69 +906,72 @@ const nodesSlice = createSlice({ }); export const { - nodesChanged, - edgesChanged, - nodeAdded, - nodesDeleted, + addNodePopoverClosed, + addNodePopoverOpened, + addNodePopoverToggled, + connectionEnded, connectionMade, connectionStarted, - connectionEnded, - shouldShowFieldTypeLegendChanged, - shouldShowMinimapPanelChanged, - nodeTemplatesBuilt, - nodeEditorReset, - imageCollectionFieldValueChanged, - fieldStringValueChanged, - fieldNumberValueChanged, + edgeDeleted, + edgesChanged, + edgesDeleted, + edgeUpdated, fieldBoardValueChanged, fieldBooleanValueChanged, - fieldImageValueChanged, fieldColorValueChanged, - fieldMainModelValueChanged, - fieldVaeModelValueChanged, - fieldLoRAModelValueChanged, - fieldEnumModelValueChanged, fieldControlNetModelValueChanged, + fieldEnumModelValueChanged, + fieldImageValueChanged, fieldIPAdapterModelValueChanged, + fieldLabelChanged, + fieldLoRAModelValueChanged, + fieldMainModelValueChanged, + fieldNumberValueChanged, fieldRefinerModelValueChanged, fieldSchedulerValueChanged, + fieldStringValueChanged, + fieldVaeModelValueChanged, + imageCollectionFieldValueChanged, + mouseOverFieldChanged, + mouseOverNodeChanged, + nodeAdded, + nodeEditorReset, + nodeEmbedWorkflowChanged, + nodeExclusivelySelected, + nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, - edgesDeleted, - shouldValidateGraphChanged, - shouldAnimateEdgesChanged, nodeOpacityChanged, - shouldSnapToGridChanged, - shouldColorEdgesChanged, - selectedNodesChanged, - selectedEdgesChanged, - workflowNameChanged, - workflowDescriptionChanged, - workflowTagsChanged, - workflowAuthorChanged, - workflowNotesChanged, - workflowVersionChanged, - workflowContactChanged, - workflowLoaded, + nodesChanged, + nodesDeleted, + nodeTemplatesBuilt, + nodeUseCacheChanged, notesNodeValueChanged, + selectedAll, + selectedEdgesChanged, + selectedNodesChanged, + selectionCopied, + selectionModeChanged, + selectionPasted, + shouldAnimateEdgesChanged, + shouldColorEdgesChanged, + shouldShowFieldTypeLegendChanged, + shouldShowMinimapPanelChanged, + shouldSnapToGridChanged, + shouldValidateGraphChanged, + viewportChanged, + workflowAuthorChanged, + workflowContactChanged, + workflowDescriptionChanged, workflowExposedFieldAdded, workflowExposedFieldRemoved, - fieldLabelChanged, - viewportChanged, - mouseOverFieldChanged, - selectionCopied, - selectionPasted, - selectedAll, - addNodePopoverOpened, - addNodePopoverClosed, - addNodePopoverToggled, - selectionModeChanged, - nodeEmbedWorkflowChanged, - nodeIsIntermediateChanged, - mouseOverNodeChanged, - nodeExclusivelySelected, - nodeUseCacheChanged, + workflowLoaded, + workflowNameChanged, + workflowNotesChanged, + workflowTagsChanged, + workflowVersionChanged, + edgeAdded, } = nodesSlice.actions; export default nodesSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 1be2d579d8..6343240a88 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.cannotConnectInputToInput'); } + // we have to figure out which is the target and which is the source + const target = handleType === 'target' ? nodeId : connectionNodeId; + const targetHandle = + handleType === 'target' ? fieldName : connectionFieldName; + const source = handleType === 'source' ? nodeId : connectionNodeId; + const sourceHandle = + handleType === 'source' ? fieldName : connectionFieldName; + if ( edges.find((edge) => { - return edge.target === nodeId && edge.targetHandle === fieldName; + edge.target === target && + edge.targetHandle === targetHandle && + edge.source === source && + edge.sourceHandle === sourceHandle; + }) + ) { + // We already have a connection from this source to this target + return i18n.t('nodes.cannotDuplicateConnection'); + } + + if ( + edges.find((edge) => { + return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples targetType !== 'CollectionItem' diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index fc8fe10ccc..eb8baf513e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -1141,6 +1141,10 @@ const zLoRAMetadataItem = z.object({ export type LoRAMetadataItem = z.infer; +const zControlNetMetadataItem = zControlField.deepPartial(); + +export type ControlNetMetadataItem = z.infer; + export const zCoreMetadata = z .object({ app_version: z.string().nullish().catch(null), @@ -1222,6 +1226,7 @@ export const zInvocationNodeData = z.object({ notes: z.string(), embedWorkflow: z.boolean(), isIntermediate: z.boolean(), + useCache: z.boolean().optional(), version: zSemVer.optional(), }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts index 6bd44db197..a6ee6a091d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -32,7 +32,8 @@ export const addSDXLRefinerToGraph = ( graph: NonNullableGraph, baseNodeId: string, modelLoaderNodeId?: string, - canvasInitImage?: ImageDTO + canvasInitImage?: ImageDTO, + canvasMaskImage?: ImageDTO ): void => { const { refinerModel, @@ -257,8 +258,30 @@ export const addSDXLRefinerToGraph = ( }; } - graph.edges.push( - { + if (graph.id === SDXL_CANVAS_INPAINT_GRAPH) { + if (isUsingScaledDimensions) { + graph.edges.push({ + source: { + node_id: MASK_RESIZE_UP, + field: 'image', + }, + destination: { + node_id: SDXL_REFINER_INPAINT_CREATE_MASK, + field: 'mask', + }, + }); + } else { + graph.nodes[SDXL_REFINER_INPAINT_CREATE_MASK] = { + ...(graph.nodes[ + SDXL_REFINER_INPAINT_CREATE_MASK + ] as CreateDenoiseMaskInvocation), + mask: canvasMaskImage, + }; + } + } + + if (graph.id === SDXL_CANVAS_OUTPAINT_GRAPH) { + graph.edges.push({ source: { node_id: isUsingScaledDimensions ? MASK_RESIZE_UP : MASK_COMBINE, field: 'image', @@ -267,18 +290,19 @@ export const addSDXLRefinerToGraph = ( node_id: SDXL_REFINER_INPAINT_CREATE_MASK, field: 'mask', }, + }); + } + + graph.edges.push({ + source: { + node_id: SDXL_REFINER_INPAINT_CREATE_MASK, + field: 'denoise_mask', }, - { - source: { - node_id: SDXL_REFINER_INPAINT_CREATE_MASK, - field: 'denoise_mask', - }, - destination: { - node_id: SDXL_REFINER_DENOISE_LATENTS, - field: 'denoise_mask', - }, - } - ); + destination: { + node_id: SDXL_REFINER_DENOISE_LATENTS, + field: 'denoise_mask', + }, + }); } if ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts index 389d510ac7..a245953c8e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts @@ -663,7 +663,8 @@ export const buildCanvasSDXLInpaintGraph = ( graph, CANVAS_COHERENCE_DENOISE_LATENTS, modelLoaderNodeId, - canvasInitImage + canvasInitImage, + canvasMaskImage ); if (seamlessXAxis || seamlessYAxis) { modelLoaderNodeId = SDXL_REFINER_SEAMLESS; diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 4fb9a0ce2c..d8561ab122 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -2,7 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppToaster } from 'app/components/Toaster'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types'; +import { + CoreMetadata, + LoRAMetadataItem, + ControlNetMetadataItem, +} from 'features/nodes/types/types'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -18,9 +22,18 @@ import { useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { ImageDTO } from 'services/api/types'; import { + controlNetModelsAdapter, loraModelsAdapter, + useGetControlNetModelsQuery, useGetLoRAModelsQuery, } from '../../../services/api/endpoints/models'; +import { + ControlNetConfig, + controlNetEnabled, + controlNetRecalled, + controlNetReset, + initialControlNet, +} from '../../controlNet/store/controlNetSlice'; import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice'; import { initialImageSelected, modelSelected } from '../store/actions'; import { @@ -38,6 +51,7 @@ import { isValidCfgScale, isValidHeight, isValidLoRAModel, + isValidControlNetModel, isValidMainModel, isValidNegativePrompt, isValidPositivePrompt, @@ -53,6 +67,11 @@ import { isValidStrength, isValidWidth, } from '../types/parameterSchemas'; +import { v4 as uuidv4 } from 'uuid'; +import { + CONTROLNET_PROCESSORS, + CONTROLNET_MODEL_DEFAULT_PROCESSORS, +} from 'features/controlNet/store/constants'; const selector = createSelector(stateSelector, ({ generation }) => { const { model } = generation; @@ -390,6 +409,121 @@ export const useRecallParameters = () => { [prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast] ); + /** + * Recall ControlNet with toast + */ + + const { controlnets } = useGetControlNetModelsQuery(undefined, { + selectFromResult: (result) => ({ + controlnets: result.data + ? controlNetModelsAdapter.getSelectors().selectAll(result.data) + : [], + }), + }); + + const prepareControlNetMetadataItem = useCallback( + (controlnetMetadataItem: ControlNetMetadataItem) => { + if (!isValidControlNetModel(controlnetMetadataItem.control_model)) { + return { controlnet: null, error: 'Invalid ControlNet model' }; + } + + const { + image, + control_model, + control_weight, + begin_step_percent, + end_step_percent, + control_mode, + resize_mode, + } = controlnetMetadataItem; + + const matchingControlNetModel = controlnets.find( + (c) => + c.base_model === control_model.base_model && + c.model_name === control_model.model_name + ); + + if (!matchingControlNetModel) { + return { controlnet: null, error: 'ControlNet model is not installed' }; + } + + const isCompatibleBaseModel = + matchingControlNetModel?.base_model === model?.base_model; + + if (!isCompatibleBaseModel) { + return { + controlnet: null, + error: 'ControlNet incompatible with currently-selected model', + }; + } + + const controlNetId = uuidv4(); + + let processorType = initialControlNet.processorType; + for (const modelSubstring in CONTROLNET_MODEL_DEFAULT_PROCESSORS) { + if (matchingControlNetModel.model_name.includes(modelSubstring)) { + processorType = + CONTROLNET_MODEL_DEFAULT_PROCESSORS[modelSubstring] || + initialControlNet.processorType; + break; + } + } + const processorNode = CONTROLNET_PROCESSORS[processorType].default; + + const controlnet: ControlNetConfig = { + isEnabled: true, + model: matchingControlNetModel, + weight: + typeof control_weight === 'number' + ? control_weight + : initialControlNet.weight, + beginStepPct: begin_step_percent || initialControlNet.beginStepPct, + endStepPct: end_step_percent || initialControlNet.endStepPct, + controlMode: control_mode || initialControlNet.controlMode, + resizeMode: resize_mode || initialControlNet.resizeMode, + controlImage: image?.image_name || null, + processedControlImage: image?.image_name || null, + processorType, + processorNode: + processorNode.type !== 'none' + ? processorNode + : initialControlNet.processorNode, + shouldAutoConfig: true, + controlNetId, + }; + + return { controlnet, error: null }; + }, + [controlnets, model?.base_model] + ); + + const recallControlNet = useCallback( + (controlnetMetadataItem: ControlNetMetadataItem) => { + const result = prepareControlNetMetadataItem(controlnetMetadataItem); + + if (!result.controlnet) { + parameterNotSetToast(result.error); + return; + } + + dispatch( + controlNetRecalled({ + ...result.controlnet, + }) + ); + + dispatch(controlNetEnabled()); + + parameterSetToast(); + }, + [ + prepareControlNetMetadataItem, + dispatch, + parameterSetToast, + parameterNotSetToast, + ] + ); + /* * Sets image as initial image with toast */ @@ -428,6 +562,7 @@ export const useRecallParameters = () => { refiner_negative_aesthetic_score, refiner_start, loras, + controlnets, } = metadata; if (isValidCfgScale(cfg_scale)) { @@ -517,6 +652,15 @@ export const useRecallParameters = () => { } }); + dispatch(controlNetReset()); + dispatch(controlNetEnabled()); + controlnets?.forEach((controlnet) => { + const result = prepareControlNetMetadataItem(controlnet); + if (result.controlnet) { + dispatch(controlNetRecalled(result.controlnet)); + } + }); + allParameterSetToast(); }, [ @@ -524,6 +668,7 @@ export const useRecallParameters = () => { allParameterSetToast, dispatch, prepareLoRAMetadataItem, + prepareControlNetMetadataItem, ] ); @@ -542,6 +687,7 @@ export const useRecallParameters = () => { recallHeight, recallStrength, recallLoRA, + recallControlNet, recallAllParameters, sendToImageToImage, }; diff --git a/invokeai/frontend/web/src/features/queue/components/InvocationCacheStatus.tsx b/invokeai/frontend/web/src/features/queue/components/InvocationCacheStatus.tsx index 423ab09376..1720f81285 100644 --- a/invokeai/frontend/web/src/features/queue/components/InvocationCacheStatus.tsx +++ b/invokeai/frontend/web/src/features/queue/components/InvocationCacheStatus.tsx @@ -1,9 +1,7 @@ import { ButtonGroup } from '@chakra-ui/react'; -import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; -import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; import ClearInvocationCacheButton from './ClearInvocationCacheButton'; import ToggleInvocationCacheButton from './ToggleInvocationCacheButton'; import StatusStatGroup from './common/StatusStatGroup'; @@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem'; const InvocationCacheStatus = () => { const { t } = useTranslation(); - const isConnected = useAppSelector((state) => state.system.isConnected); - const { data: queueStatus } = useGetQueueStatusQuery(undefined); - const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, { - pollingInterval: - isConnected && - queueStatus?.processor.is_started && - queueStatus?.queue.pending > 0 - ? 5000 - : 0, - }); + const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined); return ( diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemSkeleton.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemSkeleton.tsx index 72a5fcdc96..529c46af74 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemSkeleton.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueItemSkeleton.tsx @@ -1,46 +1,37 @@ -import { Flex, Skeleton, Text } from '@chakra-ui/react'; +import { Flex, Skeleton } from '@chakra-ui/react'; import { memo } from 'react'; import { COLUMN_WIDTHS } from './constants'; const QueueItemSkeleton = () => { return ( - + - -   + +   - - -   + + +   - - -   + + +   - - -   + + +   - - -   + + +   diff --git a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueList.tsx b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueList.tsx index 19c61b4379..e136e6df6c 100644 --- a/invokeai/frontend/web/src/features/queue/components/QueueList/QueueList.tsx +++ b/invokeai/frontend/web/src/features/queue/components/QueueList/QueueList.tsx @@ -3,6 +3,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback'; import { listCursorChanged, listPriorityChanged, @@ -23,7 +24,6 @@ import QueueItemComponent from './QueueItemComponent'; import QueueListComponent from './QueueListComponent'; import QueueListHeader from './QueueListHeader'; import { ListContext } from './types'; -import QueueItemSkeleton from './QueueItemSkeleton'; // eslint-disable-next-line @typescript-eslint/no-explicit-any type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any; @@ -126,54 +126,40 @@ const QueueList = () => { [openQueueItems, toggleQueueItem] ); + if (isLoading) { + return ; + } + + if (!queueItems.length) { + return ( + + + {t('queue.queueEmpty')} + + + ); + } + return ( - {isLoading ? ( - <> - - - - - - - - - - - - - ) : ( - <> - {queueItems.length ? ( - <> - - - - data={queueItems} - endReached={handleLoadMore} - scrollerRef={setScroller as TableVirtuosoScrollerRef} - itemContent={itemContent} - computeItemKey={computeItemKey} - components={components} - context={context} - /> - - - ) : ( - - - {t('queue.queueEmpty')} - - - )} - - )} + + + + data={queueItems} + endReached={handleLoadMore} + scrollerRef={setScroller as TableVirtuosoScrollerRef} + itemContent={itemContent} + computeItemKey={computeItemKey} + components={components} + context={context} + /> + ); }; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts index 308695dd67..1b07221a74 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useCancelCurrentQueueItem.ts @@ -1,5 +1,6 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { addToast } from 'features/system/store/systemSlice'; +import { isNil } from 'lodash-es'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { @@ -40,7 +41,7 @@ export const useCancelCurrentQueueItem = () => { }, [currentQueueItemId, dispatch, t, trigger]); const isDisabled = useMemo( - () => !isConnected || !currentQueueItemId, + () => !isConnected || isNil(currentQueueItemId), [isConnected, currentQueueItemId] ); diff --git a/invokeai/frontend/web/src/features/system/store/systemSlice.ts b/invokeai/frontend/web/src/features/system/store/systemSlice.ts index 7d31838afd..9a110f5f23 100644 --- a/invokeai/frontend/web/src/features/system/store/systemSlice.ts +++ b/invokeai/frontend/web/src/features/system/store/systemSlice.ts @@ -1,9 +1,8 @@ import { UseToastOptions } from '@chakra-ui/react'; import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { t } from 'i18next'; -import { get, startCase, truncate, upperFirst } from 'lodash-es'; +import { startCase } from 'lodash-es'; import { LogLevelName } from 'roarr'; -import { isAnySessionRejected } from 'services/api/thunks/session'; import { appSocketConnected, appSocketDisconnected, @@ -20,8 +19,7 @@ import { } from 'services/events/actions'; import { calculateStepPercentage } from '../util/calculateStepPercentage'; import { makeToast } from '../util/makeToast'; -import { SystemState, LANGUAGES } from './types'; -import { zPydanticValidationError } from './zodSchemas'; +import { LANGUAGES, SystemState } from './types'; export const initialSystemState: SystemState = { isInitialized: false, @@ -175,50 +173,6 @@ export const systemSlice = createSlice({ // *** Matchers - must be after all cases *** - /** - * Session Invoked - REJECTED - * Session Created - REJECTED - */ - builder.addMatcher(isAnySessionRejected, (state, action) => { - let errorDescription = undefined; - const duration = 5000; - - if (action.payload?.status === 422) { - const result = zPydanticValidationError.safeParse(action.payload); - if (result.success) { - result.data.error.detail.map((e) => { - state.toastQueue.push( - makeToast({ - title: truncate(upperFirst(e.msg), { length: 128 }), - status: 'error', - description: truncate( - `Path: - ${e.loc.join('.')}`, - { length: 128 } - ), - duration, - }) - ); - }); - return; - } - } else if (action.payload?.error) { - errorDescription = action.payload?.error; - } - - state.toastQueue.push( - makeToast({ - title: t('toast.serverError'), - status: 'error', - description: truncate( - get(errorDescription, 'detail', 'Unknown Error'), - { length: 128 } - ), - duration, - }) - ); - }); - /** * Any server error */ diff --git a/invokeai/frontend/web/src/features/system/store/zodSchemas.ts b/invokeai/frontend/web/src/features/system/store/zodSchemas.ts index 3a3b950019..9d66f5ae88 100644 --- a/invokeai/frontend/web/src/features/system/store/zodSchemas.ts +++ b/invokeai/frontend/web/src/features/system/store/zodSchemas.ts @@ -2,7 +2,7 @@ import { z } from 'zod'; export const zPydanticValidationError = z.object({ status: z.literal(422), - error: z.object({ + data: z.object({ detail: z.array( z.object({ loc: z.array(z.string()), diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index fb5756b121..ac7b8aa1c4 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup'; -import { InvokeTabName, tabMap } from 'features/ui/store/tabMap'; +import { InvokeTabName } from 'features/ui/store/tabMap'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { ResourceKey } from 'i18next'; import { isEqual } from 'lodash-es'; @@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue']; export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue']; const InvokeTabs = () => { - const activeTab = useAppSelector(activeTabIndexSelector); + const activeTabIndex = useAppSelector(activeTabIndexSelector); const activeTabName = useAppSelector(activeTabNameSelector); const enabledTabs = useAppSelector(enabledTabsSelector); const { t } = useTranslation(); @@ -150,13 +150,13 @@ const InvokeTabs = () => { const handleTabChange = useCallback( (index: number) => { - const activeTabName = tabMap[index]; - if (!activeTabName) { + const tab = enabledTabs[index]; + if (!tab) { return; } - dispatch(setActiveTab(activeTabName)); + dispatch(setActiveTab(tab.id)); }, - [dispatch] + [dispatch, enabledTabs] ); const { @@ -216,8 +216,8 @@ const InvokeTabs = () => { return ( {layer === 'base' && ( - dispatch(setBrushColor(newColor))} - /> + > + dispatch(setBrushColor(newColor))} + /> + )} {layer === 'mask' && ( - dispatch(setMaskColor(newColor))} - /> + > + dispatch(setMaskColor(newColor))} + /> + )} diff --git a/invokeai/frontend/web/src/features/ui/store/extraReducers.ts b/invokeai/frontend/web/src/features/ui/store/extraReducers.ts deleted file mode 100644 index 9b134e1476..0000000000 --- a/invokeai/frontend/web/src/features/ui/store/extraReducers.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { InvokeTabName, tabMap } from './tabMap'; -import { UIState } from './uiTypes'; - -export const setActiveTabReducer = ( - state: UIState, - newActiveTab: number | InvokeTabName -) => { - if (typeof newActiveTab === 'number') { - state.activeTab = newActiveTab; - } else { - state.activeTab = tabMap.indexOf(newActiveTab); - } -}; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts b/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts index 5427fa9d3b..99ee8d80f7 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSelectors.ts @@ -1,27 +1,23 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; -import { isEqual } from 'lodash-es'; - -import { InvokeTabName, tabMap } from './tabMap'; -import { UIState } from './uiTypes'; +import { isEqual, isString } from 'lodash-es'; +import { tabMap } from './tabMap'; export const activeTabNameSelector = createSelector( - (state: RootState) => state.ui, - (ui: UIState) => tabMap[ui.activeTab] as InvokeTabName, - { - memoizeOptions: { - equalityCheck: isEqual, - }, - } + (state: RootState) => state, + /** + * Previously `activeTab` was an integer, but now it's a string. + * Default to first tab in case user has integer. + */ + ({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img') ); export const activeTabIndexSelector = createSelector( - (state: RootState) => state.ui, - (ui: UIState) => ui.activeTab, - { - memoizeOptions: { - equalityCheck: isEqual, - }, + (state: RootState) => state, + ({ ui, config }) => { + const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t)); + const idx = tabs.indexOf(ui.activeTab); + return idx === -1 ? 0 : idx; } ); diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 82c9ef4e77..9782d0bfac 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; -import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { UIState } from './uiTypes'; export const initialUIState: UIState = { - activeTab: 0, + activeTab: 'txt2img', shouldShowImageDetails: false, shouldUseCanvasBetaLayout: false, shouldShowExistingModelsInSearch: false, @@ -26,7 +25,7 @@ export const uiSlice = createSlice({ initialState: initialUIState, reducers: { setActiveTab: (state, action: PayloadAction) => { - setActiveTabReducer(state, action.payload); + state.activeTab = action.payload; }, setShouldShowImageDetails: (state, action: PayloadAction) => { state.shouldShowImageDetails = action.payload; @@ -73,7 +72,7 @@ export const uiSlice = createSlice({ }, extraReducers(builder) { builder.addCase(initialImageChanged, (state) => { - setActiveTabReducer(state, 'img2img'); + state.activeTab = 'img2img'; }); }, }); diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 41a359a651..1b9fee6989 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,4 +1,5 @@ import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; +import { InvokeTabName } from './tabMap'; export type Coordinates = { x: number; @@ -13,7 +14,7 @@ export type Dimensions = { export type Rect = Coordinates & Dimensions; export interface UIState { - activeTab: number; + activeTab: InvokeTabName; shouldShowImageDetails: boolean; shouldUseCanvasBetaLayout: boolean; shouldShowExistingModelsInSearch: boolean; diff --git a/invokeai/frontend/web/src/services/api/thunks/session.ts b/invokeai/frontend/web/src/services/api/thunks/session.ts deleted file mode 100644 index 837fd7a28e..0000000000 --- a/invokeai/frontend/web/src/services/api/thunks/session.ts +++ /dev/null @@ -1,184 +0,0 @@ -import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit'; -import { $queueId } from 'features/queue/store/queueNanoStore'; -import { isObject } from 'lodash-es'; -import { $client } from 'services/api/client'; -import { paths } from 'services/api/schema'; -import { O } from 'ts-toolbelt'; - -type CreateSessionArg = { - graph: NonNullable< - paths['/api/v1/sessions/']['post']['requestBody'] - >['content']['application/json']; -}; - -type CreateSessionResponse = O.Required< - NonNullable< - paths['/api/v1/sessions/']['post']['requestBody'] - >['content']['application/json'], - 'id' ->; - -type CreateSessionThunkConfig = { - rejectValue: { arg: CreateSessionArg; status: number; error: unknown }; -}; - -/** - * `SessionsService.createSession()` thunk - */ -export const sessionCreated = createAsyncThunk< - CreateSessionResponse, - CreateSessionArg, - CreateSessionThunkConfig ->('api/sessionCreated', async (arg, { rejectWithValue }) => { - const { graph } = arg; - const { POST } = $client.get(); - const { data, error, response } = await POST('/api/v1/sessions/', { - body: graph, - params: { query: { queue_id: $queueId.get() } }, - }); - - if (error) { - return rejectWithValue({ arg, status: response.status, error }); - } - - return data; -}); - -type InvokedSessionArg = { - session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id']; -}; - -type InvokedSessionResponse = - paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json']; - -type InvokedSessionThunkConfig = { - rejectValue: { - arg: InvokedSessionArg; - error: unknown; - status: number; - }; -}; - -const isErrorWithStatus = (error: unknown): error is { status: number } => - isObject(error) && 'status' in error; - -const isErrorWithDetail = (error: unknown): error is { detail: string } => - isObject(error) && 'detail' in error; - -/** - * `SessionsService.invokeSession()` thunk - */ -export const sessionInvoked = createAsyncThunk< - InvokedSessionResponse, - InvokedSessionArg, - InvokedSessionThunkConfig ->('api/sessionInvoked', async (arg, { rejectWithValue }) => { - const { session_id } = arg; - const { PUT } = $client.get(); - const { error, response } = await PUT( - '/api/v1/sessions/{session_id}/invoke', - { - params: { - query: { queue_id: $queueId.get(), all: true }, - path: { session_id }, - }, - } - ); - - if (error) { - if (isErrorWithStatus(error) && error.status === 403) { - return rejectWithValue({ - arg, - status: response.status, - // eslint-disable-next-line @typescript-eslint/no-explicit-any - error: (error as any).body.detail, - }); - } - if (isErrorWithDetail(error) && response.status === 403) { - return rejectWithValue({ - arg, - status: response.status, - error: error.detail, - }); - } - if (error) { - return rejectWithValue({ arg, status: response.status, error }); - } - } -}); - -type CancelSessionArg = - paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path']; - -type CancelSessionResponse = - paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json']; - -type CancelSessionThunkConfig = { - rejectValue: { - arg: CancelSessionArg; - error: unknown; - }; -}; - -/** - * `SessionsService.cancelSession()` thunk - */ -export const sessionCanceled = createAsyncThunk< - CancelSessionResponse, - CancelSessionArg, - CancelSessionThunkConfig ->('api/sessionCanceled', async (arg, { rejectWithValue }) => { - const { session_id } = arg; - const { DELETE } = $client.get(); - const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', { - params: { - path: { session_id }, - }, - }); - - if (error) { - return rejectWithValue({ arg, error }); - } - - return data; -}); - -type ListSessionsArg = { - params: paths['/api/v1/sessions/']['get']['parameters']; -}; - -type ListSessionsResponse = - paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json']; - -type ListSessionsThunkConfig = { - rejectValue: { - arg: ListSessionsArg; - error: unknown; - }; -}; - -/** - * `SessionsService.listSessions()` thunk - */ -export const listedSessions = createAsyncThunk< - ListSessionsResponse, - ListSessionsArg, - ListSessionsThunkConfig ->('api/listSessions', async (arg, { rejectWithValue }) => { - const { params } = arg; - const { GET } = $client.get(); - const { data, error } = await GET('/api/v1/sessions/', { - params, - }); - - if (error) { - return rejectWithValue({ arg, error }); - } - - return data; -}); - -export const isAnySessionRejected = isAnyOf( - sessionCreated.rejected, - sessionInvoked.rejected -); diff --git a/invokeai/frontend/web/src/theme/theme.ts b/invokeai/frontend/web/src/theme/theme.ts index 3b83ea2393..ae38aefca0 100644 --- a/invokeai/frontend/web/src/theme/theme.ts +++ b/invokeai/frontend/web/src/theme/theme.ts @@ -1,5 +1,4 @@ -import { ThemeOverride } from '@chakra-ui/react'; - +import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react'; import { InvokeAIColors } from './colors/colors'; import { accordionTheme } from './components/accordion'; import { buttonTheme } from './components/button'; @@ -149,3 +148,7 @@ export const theme: ThemeOverride = { Tooltip: tooltipTheme, }, }; + +export const TOAST_OPTIONS: ToastProviderProps = { + defaultOptions: { isClosable: true }, +};